More junit 4 removal, all tests compile. FIxed parameterized test invocation. Deleted nd4j-parameter-server-status that used play

master
agibsonccc 2021-03-17 20:04:53 +09:00
parent 3c6014271e
commit e0077c38a9
306 changed files with 3239 additions and 6691 deletions

View File

@ -34,8 +34,9 @@ import java.net.URISyntaxException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Random; import java.util.Random;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
/** /**
* *

View File

@ -32,9 +32,7 @@ import org.junit.jupiter.api.Test;
import java.io.File; import java.io.File;
import java.io.OutputStream; import java.io.OutputStream;
import static junit.framework.TestCase.assertTrue; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
public class PartitionerTests extends BaseND4JTest { public class PartitionerTests extends BaseND4JTest {
@Test @Test

View File

@ -58,7 +58,7 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.Assert;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -71,7 +71,7 @@ import java.io.ObjectOutputStream;
import java.util.*; import java.util.*;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static junit.framework.TestCase.assertEquals;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class TestTransforms extends BaseND4JTest { public class TestTransforms extends BaseND4JTest {
@ -277,22 +277,22 @@ public class TestTransforms extends BaseND4JTest {
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS); List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
outputColumns.add(NEW_COLUMN); outputColumns.add(NEW_COLUMN);
Schema newSchema = transform.transform(schema); Schema newSchema = transform.transform(schema);
Assert.assertEquals(outputColumns, newSchema.getColumnNames()); assertEquals(outputColumns, newSchema.getColumnNames());
List<Writable> input = new ArrayList<>(); List<Writable> input = new ArrayList<>();
input.addAll(COLUMN_VALUES); input.addAll(COLUMN_VALUES);
transform.setInputSchema(schema); transform.setInputSchema(schema);
List<Writable> transformed = transform.map(input); List<Writable> transformed = transform.map(input);
Assert.assertEquals(NEW_COLUMN_VALUE, transformed.get(transformed.size() - 1).toString()); assertEquals(NEW_COLUMN_VALUE, transformed.get(transformed.size() - 1).toString());
List<Text> outputColumnValues = new ArrayList<>(COLUMN_VALUES); List<Text> outputColumnValues = new ArrayList<>(COLUMN_VALUES);
outputColumnValues.add(new Text(NEW_COLUMN_VALUE)); outputColumnValues.add(new Text(NEW_COLUMN_VALUE));
Assert.assertEquals(outputColumnValues, transformed); assertEquals(outputColumnValues, transformed);
String s = JsonMappers.getMapper().writeValueAsString(transform); String s = JsonMappers.getMapper().writeValueAsString(transform);
Transform transform2 = JsonMappers.getMapper().readValue(s, ConcatenateStringColumns.class); Transform transform2 = JsonMappers.getMapper().readValue(s, ConcatenateStringColumns.class);
Assert.assertEquals(transform, transform2); assertEquals(transform, transform2);
} }
@Test @Test
@ -309,7 +309,7 @@ public class TestTransforms extends BaseND4JTest {
transform.setInputSchema(schema); transform.setInputSchema(schema);
Schema newSchema = transform.transform(schema); Schema newSchema = transform.transform(schema);
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS); List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
Assert.assertEquals(outputColumns, newSchema.getColumnNames()); assertEquals(outputColumns, newSchema.getColumnNames());
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.LOWER); transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.LOWER);
transform.setInputSchema(schema); transform.setInputSchema(schema);
@ -320,8 +320,8 @@ public class TestTransforms extends BaseND4JTest {
output.add(new Text(TEXT_LOWER_CASE)); output.add(new Text(TEXT_LOWER_CASE));
output.add(new Text(TEXT_MIXED_CASE)); output.add(new Text(TEXT_MIXED_CASE));
List<Writable> transformed = transform.map(input); List<Writable> transformed = transform.map(input);
Assert.assertEquals(transformed.get(0).toString(), TEXT_LOWER_CASE); assertEquals(transformed.get(0).toString(), TEXT_LOWER_CASE);
Assert.assertEquals(transformed, output); assertEquals(transformed, output);
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.UPPER); transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.UPPER);
transform.setInputSchema(schema); transform.setInputSchema(schema);
@ -329,12 +329,12 @@ public class TestTransforms extends BaseND4JTest {
output.add(new Text(TEXT_UPPER_CASE)); output.add(new Text(TEXT_UPPER_CASE));
output.add(new Text(TEXT_MIXED_CASE)); output.add(new Text(TEXT_MIXED_CASE));
transformed = transform.map(input); transformed = transform.map(input);
Assert.assertEquals(transformed.get(0).toString(), TEXT_UPPER_CASE); assertEquals(transformed.get(0).toString(), TEXT_UPPER_CASE);
Assert.assertEquals(transformed, output); assertEquals(transformed, output);
String s = JsonMappers.getMapper().writeValueAsString(transform); String s = JsonMappers.getMapper().writeValueAsString(transform);
Transform transform2 = JsonMappers.getMapper().readValue(s, ChangeCaseStringTransform.class); Transform transform2 = JsonMappers.getMapper().readValue(s, ChangeCaseStringTransform.class);
Assert.assertEquals(transform, transform2); assertEquals(transform, transform2);
} }
@Test @Test
@ -1530,7 +1530,7 @@ public class TestTransforms extends BaseND4JTest {
String json = JsonMappers.getMapper().writeValueAsString(t); String json = JsonMappers.getMapper().writeValueAsString(t);
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToCountsNDArrayTransform.class); Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToCountsNDArrayTransform.class);
Assert.assertEquals(t, transform2); assertEquals(t, transform2);
} }
@ -1551,7 +1551,7 @@ public class TestTransforms extends BaseND4JTest {
String json = JsonMappers.getMapper().writeValueAsString(t); String json = JsonMappers.getMapper().writeValueAsString(t);
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToIndicesNDArrayTransform.class); Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToIndicesNDArrayTransform.class);
Assert.assertEquals(t, transform2); assertEquals(t, transform2);
} }

View File

@ -54,10 +54,8 @@ import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
import java.util.*; import java.util.*;
import static java.nio.channels.Channels.newChannel; import static java.nio.channels.Channels.newChannel;
import static junit.framework.TestCase.assertTrue; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path; import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;

View File

@ -42,7 +42,7 @@ import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.datavec.api.transform.schema.Schema.Builder; import static org.datavec.api.transform.schema.Schema.Builder;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;

View File

@ -40,8 +40,9 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class NormalizationTests extends BaseSparkTest { public class NormalizationTests extends BaseSparkTest {

View File

@ -165,9 +165,8 @@
</plugin> </plugin>
<plugin> <plugin>
<artifactId>maven-surefire-plugin</artifactId> <artifactId>maven-surefire-plugin</artifactId>
<version>${maven-surefire-plugin.version}</version>
<configuration> <configuration>
<argLine> "</argLine> <argLine></argLine>
<!-- <!--
By default: Surefire will set the classpath based on the manifest. Because tests are not included By default: Surefire will set the classpath based on the manifest. Because tests are not included

View File

@ -32,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.List; import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Disabled @Disabled

View File

@ -27,12 +27,11 @@ import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static junit.framework.TestCase.assertNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Eval Json Test") @DisplayName("Eval Json Test")
class EvalJsonTest extends BaseDL4JTest { class EvalJsonTest extends BaseDL4JTest {

View File

@ -79,7 +79,6 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
return 90000L; return 90000L;
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
public void testYoloOutputLayer(CNN2DFormat format) { public void testYoloOutputLayer(CNN2DFormat format) {
@ -180,7 +179,6 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
public void yoloGradientCheckRealData(@TempDir Path testDir,CNN2DFormat format) throws Exception { public void yoloGradientCheckRealData(@TempDir Path testDir,CNN2DFormat format) throws Exception {

View File

@ -150,7 +150,6 @@ class BidirectionalTest extends BaseDL4JTest {
} }
@DisplayName("Compare Implementations Comp Graph") @DisplayName("Compare Implementations Comp Graph")
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
void compareImplementationsCompGraph(RNNFormat rnnFormat) { void compareImplementationsCompGraph(RNNFormat rnnFormat) {

View File

@ -55,7 +55,6 @@ class MaskZeroLayerTest extends BaseDL4JTest {
} }
@DisplayName("Activate") @DisplayName("Activate")
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
void activate(RNNFormat rnnDataFormat) { void activate(RNNFormat rnnDataFormat) {
@ -96,7 +95,6 @@ class MaskZeroLayerTest extends BaseDL4JTest {
@DisplayName("Test Serialization") @DisplayName("Test Serialization")
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
void testSerialization(RNNFormat rnnDataFormat) { void testSerialization(RNNFormat rnnDataFormat) {

View File

@ -109,7 +109,6 @@ public class RnnDataFormatTests extends BaseDL4JTest {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
public void testLSTM(boolean helpers, public void testLSTM(boolean helpers,

View File

@ -62,7 +62,6 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
public void testLastTimeStepVertex(RNNFormat rnnDataFormat) { public void testLastTimeStepVertex(RNNFormat rnnDataFormat) {
@ -127,7 +126,6 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
TestUtils.testModelSerialization(graph); TestUtils.testModelSerialization(graph);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
public void testMaskingAndAllMasked(RNNFormat rnnDataFormat) { public void testMaskingAndAllMasked(RNNFormat rnnDataFormat) {

View File

@ -65,7 +65,6 @@ public class TestRnnLayers extends BaseDL4JTest {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
public void testTimeStepIs3Dimensional(RNNFormat rnnDataFormat) { public void testTimeStepIs3Dimensional(RNNFormat rnnDataFormat) {
@ -117,7 +116,6 @@ public class TestRnnLayers extends BaseDL4JTest {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
public void testDropoutRecurrentLayers(RNNFormat rnnDataFormat){ public void testDropoutRecurrentLayers(RNNFormat rnnDataFormat){
@ -217,7 +215,6 @@ public class TestRnnLayers extends BaseDL4JTest {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
public void testMismatchedInputLabelLength(RNNFormat rnnDataFormat){ public void testMismatchedInputLabelLength(RNNFormat rnnDataFormat){

View File

@ -55,7 +55,6 @@ public class TestSimpleRnn extends BaseDL4JTest {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
public void testSimpleRnn(RNNFormat rnnDataFormat) { public void testSimpleRnn(RNNFormat rnnDataFormat) {
@ -126,7 +125,6 @@ public class TestSimpleRnn extends BaseDL4JTest {
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
public void testBiasInit(RNNFormat rnnDataFormat) { public void testBiasInit(RNNFormat rnnDataFormat) {

View File

@ -61,7 +61,6 @@ public class TestTimeDistributed extends BaseDL4JTest {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("#params") @MethodSource("#params")
public void testTimeDistributed(RNNFormat rnnDataFormat){ public void testTimeDistributed(RNNFormat rnnDataFormat){

View File

@ -26,8 +26,8 @@ import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import oshi.json.SystemInfo; import oshi.json.SystemInfo;
import static junit.framework.TestCase.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@Disabled("AB 2019/05/24 - Failing on CI - \"Could not initialize class oshi.jna.platform.linux.Libc\" - Issue #7657") @Disabled("AB 2019/05/24 - Failing on CI - \"Could not initialize class oshi.jna.platform.linux.Libc\" - Issue #7657")
public class TestHardWareMetric extends BaseDL4JTest { public class TestHardWareMetric extends BaseDL4JTest {

View File

@ -45,9 +45,9 @@ import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import java.io.*; import java.io.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.Assume.assumeNotNull; import static org.junit.jupiter.api.Assumptions.*;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path; import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@ -67,11 +67,11 @@ class ModelGuesserTest extends BaseDL4JTest {
File f = Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"); File f = Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5");
assertTrue(f.exists()); assertTrue(f.exists());
Model guess1 = ModelGuesser.loadModelGuess(f.getAbsolutePath()); Model guess1 = ModelGuesser.loadModelGuess(f.getAbsolutePath());
assumeNotNull(guess1); assertNotNull(guess1);
f = Resources.asFile("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5"); f = Resources.asFile("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5");
assertTrue(f.exists()); assertTrue(f.exists());
Model guess2 = ModelGuesser.loadModelGuess(f.getAbsolutePath()); Model guess2 = ModelGuesser.loadModelGuess(f.getAbsolutePath());
assumeNotNull(guess2); assertNotNull(guess2);
} }
@Test @Test
@ -81,13 +81,13 @@ class ModelGuesserTest extends BaseDL4JTest {
assertTrue(f.exists()); assertTrue(f.exists());
try (InputStream inputStream = new FileInputStream(f)) { try (InputStream inputStream = new FileInputStream(f)) {
Model guess1 = ModelGuesser.loadModelGuess(inputStream); Model guess1 = ModelGuesser.loadModelGuess(inputStream);
assumeNotNull(guess1); assertNotNull(guess1);
} }
f = Resources.asFile("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5"); f = Resources.asFile("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5");
assertTrue(f.exists()); assertTrue(f.exists());
try (InputStream inputStream = new FileInputStream(f)) { try (InputStream inputStream = new FileInputStream(f)) {
Model guess1 = ModelGuesser.loadModelGuess(inputStream); Model guess1 = ModelGuesser.loadModelGuess(inputStream);
assumeNotNull(guess1); assertNotNull(guess1);
} }
} }
@ -157,7 +157,7 @@ class ModelGuesserTest extends BaseDL4JTest {
ModelSerializer.writeModel(net, tempFile, true); ModelSerializer.writeModel(net, tempFile, true);
try (InputStream inputStream = new FileInputStream(tempFile)) { try (InputStream inputStream = new FileInputStream(tempFile)) {
MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(inputStream); MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(inputStream);
assumeNotNull(network); assertNotNull(network);
assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson()); assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson());
assertEquals(net.params(), network.params()); assertEquals(net.params(), network.params());
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());

View File

@ -52,7 +52,7 @@ import java.util.zip.ZipFile;
import java.util.zip.ZipInputStream; import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream; import java.util.zip.ZipOutputStream;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class ModelValidatorTests extends BaseDL4JTest { public class ModelValidatorTests extends BaseDL4JTest {
@ -91,7 +91,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
ValidationResult vr2 = DL4JModelValidator.validateMultiLayerNetwork(f2); ValidationResult vr2 = DL4JModelValidator.validateMultiLayerNetwork(f2);
assertFalse(vr2.isValid()); assertFalse(vr2.isValid());
String s = vr2.getIssues().get(0); String s = vr2.getIssues().get(0);
assertTrue(s, s.contains("zip") && s.contains("corrupt")); assertTrue(s.contains("zip") && s.contains("corrupt"), s);
assertEquals("MultiLayerNetwork", vr2.getFormatType()); assertEquals("MultiLayerNetwork", vr2.getFormatType());
assertEquals(MultiLayerNetwork.class, vr2.getFormatClass()); assertEquals(MultiLayerNetwork.class, vr2.getFormatClass());
assertNotNull(vr2.getException()); assertNotNull(vr2.getException());
@ -108,7 +108,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
assertFalse(vr3.isValid()); assertFalse(vr3.isValid());
s = vr3.getIssues().get(0); s = vr3.getIssues().get(0);
assertEquals(1, vr3.getIssues().size()); assertEquals(1, vr3.getIssues().size());
assertTrue(s, s.contains("missing") && s.contains("configuration")); assertTrue(s.contains("missing") && s.contains("configuration"), s);
assertEquals("MultiLayerNetwork", vr3.getFormatType()); assertEquals("MultiLayerNetwork", vr3.getFormatType());
assertEquals(MultiLayerNetwork.class, vr3.getFormatClass()); assertEquals(MultiLayerNetwork.class, vr3.getFormatClass());
assertNull(vr3.getException()); assertNull(vr3.getException());
@ -126,7 +126,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
assertFalse(vr4.isValid()); assertFalse(vr4.isValid());
s = vr4.getIssues().get(0); s = vr4.getIssues().get(0);
assertEquals(1, vr4.getIssues().size()); assertEquals(1, vr4.getIssues().size());
assertTrue(s, s.contains("missing") && s.contains("coefficients")); assertTrue(s.contains("missing") && s.contains("coefficients"), s);
assertEquals("MultiLayerNetwork", vr4.getFormatType()); assertEquals("MultiLayerNetwork", vr4.getFormatType());
assertEquals(MultiLayerNetwork.class, vr4.getFormatClass()); assertEquals(MultiLayerNetwork.class, vr4.getFormatClass());
assertNull(vr4.getException()); assertNull(vr4.getException());
@ -169,7 +169,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
assertFalse(vr6.isValid()); assertFalse(vr6.isValid());
s = vr6.getIssues().get(0); s = vr6.getIssues().get(0);
assertEquals(1, vr6.getIssues().size()); assertEquals(1, vr6.getIssues().size());
assertTrue(s, s.contains("JSON") && s.contains("valid") && s.contains("MultiLayerConfiguration")); assertTrue(s.contains("JSON") && s.contains("valid") && s.contains("MultiLayerConfiguration"), s);
assertEquals("MultiLayerNetwork", vr6.getFormatType()); assertEquals("MultiLayerNetwork", vr6.getFormatType());
assertEquals(MultiLayerNetwork.class, vr6.getFormatClass()); assertEquals(MultiLayerNetwork.class, vr6.getFormatClass());
assertNotNull(vr6.getException()); assertNotNull(vr6.getException());
@ -209,7 +209,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
ValidationResult vr2 = DL4JModelValidator.validateComputationGraph(f2); ValidationResult vr2 = DL4JModelValidator.validateComputationGraph(f2);
assertFalse(vr2.isValid()); assertFalse(vr2.isValid());
String s = vr2.getIssues().get(0); String s = vr2.getIssues().get(0);
assertTrue(s, s.contains("zip") && s.contains("corrupt")); assertTrue(s.contains("zip") && s.contains("corrupt"), s);
assertEquals("ComputationGraph", vr2.getFormatType()); assertEquals("ComputationGraph", vr2.getFormatType());
assertEquals(ComputationGraph.class, vr2.getFormatClass()); assertEquals(ComputationGraph.class, vr2.getFormatClass());
assertNotNull(vr2.getException()); assertNotNull(vr2.getException());
@ -226,7 +226,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
assertFalse(vr3.isValid()); assertFalse(vr3.isValid());
s = vr3.getIssues().get(0); s = vr3.getIssues().get(0);
assertEquals(1, vr3.getIssues().size()); assertEquals(1, vr3.getIssues().size());
assertTrue(s, s.contains("missing") && s.contains("configuration")); assertTrue(s.contains("missing") && s.contains("configuration"), s);
assertEquals("ComputationGraph", vr3.getFormatType()); assertEquals("ComputationGraph", vr3.getFormatType());
assertEquals(ComputationGraph.class, vr3.getFormatClass()); assertEquals(ComputationGraph.class, vr3.getFormatClass());
assertNull(vr3.getException()); assertNull(vr3.getException());
@ -244,7 +244,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
assertFalse(vr4.isValid()); assertFalse(vr4.isValid());
s = vr4.getIssues().get(0); s = vr4.getIssues().get(0);
assertEquals(1, vr4.getIssues().size()); assertEquals(1, vr4.getIssues().size());
assertTrue(s, s.contains("missing") && s.contains("coefficients")); assertTrue(s.contains("missing") && s.contains("coefficients"), s);
assertEquals("ComputationGraph", vr4.getFormatType()); assertEquals("ComputationGraph", vr4.getFormatType());
assertEquals(ComputationGraph.class, vr4.getFormatClass()); assertEquals(ComputationGraph.class, vr4.getFormatClass());
assertNull(vr4.getException()); assertNull(vr4.getException());
@ -287,7 +287,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
assertFalse(vr6.isValid()); assertFalse(vr6.isValid());
s = vr6.getIssues().get(0); s = vr6.getIssues().get(0);
assertEquals(1, vr6.getIssues().size()); assertEquals(1, vr6.getIssues().size());
assertTrue(s, s.contains("JSON") && s.contains("valid") && s.contains("ComputationGraphConfiguration")); assertTrue(s.contains("JSON") && s.contains("valid") && s.contains("ComputationGraphConfiguration"), s);
assertEquals("ComputationGraph", vr6.getFormatType()); assertEquals("ComputationGraph", vr6.getFormatType());
assertEquals(ComputationGraph.class, vr6.getFormatClass()); assertEquals(ComputationGraph.class, vr6.getFormatClass());
assertNotNull(vr6.getException()); assertNotNull(vr6.getException());

View File

@ -24,9 +24,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.core.util.UIDProvider; import org.deeplearning4j.core.util.UIDProvider;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static junit.framework.TestCase.assertTrue; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
public class TestUIDProvider extends BaseDL4JTest { public class TestUIDProvider extends BaseDL4JTest {

View File

@ -89,6 +89,10 @@
<groupId>org.junit.jupiter</groupId> <groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId> <artifactId>junit-jupiter-engine</artifactId>
</dependency> </dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
</dependency>
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-common-tests</artifactId> <artifactId>deeplearning4j-common-tests</artifactId>
@ -120,113 +124,4 @@
</dependency> </dependency>
</dependencies> </dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<inherited>true</inherited>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
<configuration>
<environmentVariables>
</environmentVariables>
<testSourceDirectory>src/test/java</testSourceDirectory>
<includes>
<include>*.java</include>
<include>**/*.java</include>
<include>**/Test*.java</include>
<include>**/*Test.java</include>
<include>**/*TestCase.java</include>
</includes>
<junitArtifactName>org.junit.jupiter:junit-jupiter</junitArtifactName>
<systemPropertyVariables>
<org.nd4j.linalg.defaultbackend>
org.nd4j.linalg.cpu.nativecpu.CpuBackend
</org.nd4j.linalg.defaultbackend>
<org.nd4j.linalg.tests.backendstorun>
org.nd4j.linalg.cpu.nativecpu.CpuBackend
</org.nd4j.linalg.tests.backendstorun>
</systemPropertyVariables>
<!--
Maximum heap size was set to 8g, as a minimum required value for tests run.
Depending on a build machine, default value is not always enough.
For testing large zoo models, this may not be enough (so comment it out).
-->
<argLine></argLine>
</configuration>
</plugin>
</plugins>
</build>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-11.0</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<dependencies>
<dependency>
<groupId>org.apache.maven.surefire</groupId>
<artifactId>surefire-junit47</artifactId>
<version>2.19.1</version>
</dependency>
</dependencies>
<configuration>
<environmentVariables>
</environmentVariables>
<testSourceDirectory>src/test/java</testSourceDirectory>
<includes>
<include>*.java</include>
<include>**/*.java</include>
<include>**/Test*.java</include>
<include>**/*Test.java</include>
<include>**/*TestCase.java</include>
</includes>
<junitArtifactName>org.junit.jupiter:junit-jupiter</junitArtifactName>
<systemPropertyVariables>
<org.nd4j.linalg.defaultbackend>
org.nd4j.linalg.jcublas.JCublasBackend
</org.nd4j.linalg.defaultbackend>
<org.nd4j.linalg.tests.backendstorun>
org.nd4j.linalg.jcublas.JCublasBackend
</org.nd4j.linalg.tests.backendstorun>
</systemPropertyVariables>
<!--
Maximum heap size was set to 6g, as a minimum required value for tests run.
Depending on a build machine, default value is not always enough.
-->
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
</project> </project>

View File

@ -107,17 +107,18 @@
</dependencyManagement> </dependencyManagement>
<dependencies> <dependencies>
<!-- For unit tests -->
<dependency> <dependency>
<groupId>org.junit.jupiter</groupId> <groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId> <artifactId>junit-jupiter-api</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.junit.jupiter</groupId> <groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId> <artifactId>junit-jupiter-engine</artifactId>
<version>${junit.version}</version> </dependency>
<scope>test</scope> <dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.projectlombok</groupId> <groupId>org.projectlombok</groupId>
@ -230,7 +231,6 @@
<plugins> <plugins>
<plugin> <plugin>
<artifactId>maven-surefire-plugin</artifactId> <artifactId>maven-surefire-plugin</artifactId>
<version>${maven-surefire-plugin.version}</version>
<inherited>true</inherited> <inherited>true</inherited>
<configuration> <configuration>
<!-- <!--
@ -250,6 +250,13 @@
<include>**/*.java</include> <include>**/*.java</include>
</includes> </includes>
</configuration> </configuration>
<dependencies>
<dependency>
<groupId>org.apache.maven.surefire</groupId>
<artifactId>surefire-junit-platform</artifactId>
<version>${maven-surefire-plugin.version}</version>
</dependency>
</dependencies>
</plugin> </plugin>
<plugin> <plugin>
<groupId>org.eclipse.m2e</groupId> <groupId>org.eclipse.m2e</groupId>
@ -316,6 +323,21 @@
<artifactId>nd4j-native</artifactId> <artifactId>nd4j-native</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<version>${junit.version}</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<version>${junit.version}</version>
</dependency>
<dependency>
<groupId>org.apache.maven.surefire</groupId>
<artifactId>surefire-junit-platform</artifactId>
<version>${maven-surefire-plugin.version}</version>
</dependency>
</dependencies> </dependencies>
<configuration> <configuration>
<environmentVariables> <environmentVariables>
@ -344,7 +366,7 @@
For testing large zoo models, this may not be enough (so comment it out). For testing large zoo models, this may not be enough (so comment it out).
--> -->
<argLine> "</argLine> <argLine></argLine>
</configuration> </configuration>
</plugin> </plugin>
</plugins> </plugins>
@ -376,13 +398,7 @@
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId> <artifactId>maven-surefire-plugin</artifactId>
<dependencies> <version>${maven-surefire-plugin.version}</version>
<dependency>
<groupId>org.junit</groupId>
<artifactId>surefire-junit5</artifactId>
<version>5.0.0-ALPHA</version>
</dependency>
</dependencies>
<configuration> <configuration>
<environmentVariables> <environmentVariables>
</environmentVariables> </environmentVariables>
@ -409,6 +425,7 @@
--> -->
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine> <argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
</configuration> </configuration>
</plugin> </plugin>
</plugins> </plugins>
</build> </build>

View File

@ -1259,7 +1259,7 @@ char* DoubleToBuffer(double value, char* buffer) {
// DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all // DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all
// platforms these days. Just in case some system exists where DBL_DIG // platforms these days. Just in case some system exists where DBL_DIG
// is significantly larger -- and risks overflowing our buffer -- we have // is significantly larger -- and risks overflowing our buffer -- we have
// this assert. // this
GOOGLE_COMPILE_ASSERT(DBL_DIG < 20, DBL_DIG_is_too_big); GOOGLE_COMPILE_ASSERT(DBL_DIG < 20, DBL_DIG_is_too_big);
if (value == std::numeric_limits<double>::infinity()) { if (value == std::numeric_limits<double>::infinity()) {
@ -1377,7 +1377,7 @@ char* FloatToBuffer(float value, char* buffer) {
// FLT_DIG is 6 for IEEE-754 floats, which are used on almost all // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
// platforms these days. Just in case some system exists where FLT_DIG // platforms these days. Just in case some system exists where FLT_DIG
// is significantly larger -- and risks overflowing our buffer -- we have // is significantly larger -- and risks overflowing our buffer -- we have
// this assert. // this
GOOGLE_COMPILE_ASSERT(FLT_DIG < 10, FLT_DIG_is_too_big); GOOGLE_COMPILE_ASSERT(FLT_DIG < 10, FLT_DIG_is_too_big);
if (value == std::numeric_limits<double>::infinity()) { if (value == std::numeric_limits<double>::infinity()) {

View File

@ -182,7 +182,7 @@ size_t DoubleToBuffer(double value, char* buffer) {
// DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all // DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all
// platforms these days. Just in case some system exists where DBL_DIG // platforms these days. Just in case some system exists where DBL_DIG
// is significantly larger -- and risks overflowing our buffer -- we have // is significantly larger -- and risks overflowing our buffer -- we have
// this assert. // this
static_assert(DBL_DIG < 20, "DBL_DIG is too big"); static_assert(DBL_DIG < 20, "DBL_DIG is too big");
if (std::abs(value) <= kDoublePrecisionCheckMax) { if (std::abs(value) <= kDoublePrecisionCheckMax) {
@ -363,7 +363,7 @@ size_t FloatToBuffer(float value, char* buffer) {
// FLT_DIG is 6 for IEEE-754 floats, which are used on almost all // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
// platforms these days. Just in case some system exists where FLT_DIG // platforms these days. Just in case some system exists where FLT_DIG
// is significantly larger -- and risks overflowing our buffer -- we have // is significantly larger -- and risks overflowing our buffer -- we have
// this assert. // this
static_assert(FLT_DIG < 10, "FLT_DIG is too big"); static_assert(FLT_DIG < 10, "FLT_DIG is too big");
int snprintf_result = int snprintf_result =

View File

@ -139,13 +139,6 @@
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId> <artifactId>maven-surefire-plugin</artifactId>
<dependencies>
<dependency>
<groupId>org.apache.maven.surefire</groupId>
<artifactId>surefire-junit47</artifactId>
<version>2.19.1</version>
</dependency>
</dependencies>
<configuration> <configuration>
<environmentVariables> <environmentVariables>
<LD_LIBRARY_PATH> <LD_LIBRARY_PATH>

View File

@ -81,7 +81,6 @@
<build> <build>
<plugins> <plugins>
<!-- Skip execution of Javadoc since some versions run out of memory --> <!-- Skip execution of Javadoc since some versions run out of memory -->
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
@ -124,7 +123,6 @@
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId> <artifactId>maven-surefire-plugin</artifactId>
<version>2.19.1</version>
<configuration> <configuration>
<environmentVariables> <environmentVariables>
<LD_LIBRARY_PATH>${env.LD_LIBRARY_PATH}:${user.dir}</LD_LIBRARY_PATH> <LD_LIBRARY_PATH>${env.LD_LIBRARY_PATH}:${user.dir}</LD_LIBRARY_PATH>

View File

@ -20,21 +20,19 @@
package org.nd4j; package org.nd4j;
import org.junit.AfterClass; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.nd4j.autodiff.opvalidation.*; import org.nd4j.autodiff.opvalidation.*;
import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.OpValidation;
//import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff; //import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assume.assumeFalse; import static org.junit.jupiter.api.Assumptions.assumeFalse;
@RunWith(Suite.class)
@Suite.SuiteClasses({ /*@Suite.SuiteClasses({
//Note: these will be run as part of the suite only, and will NOT be run again separately //Note: these will be run as part of the suite only, and will NOT be run again separately
LayerOpValidation.class, LayerOpValidation.class,
LossOpValidation.class, LossOpValidation.class,
@ -48,7 +46,7 @@ import static org.junit.Assume.assumeFalse;
//TF import tests //TF import tests
//TFGraphTestAllSameDiff.class //TFGraphTestAllSameDiff.class
//TFGraphTestAllLibnd4j.class //TFGraphTestAllLibnd4j.class
}) })*/
//IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test" //IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test"
// With it ignored here, the individual tests will run outside (i.e., separately/independently) of the suite in both "mvn test" and IntelliJ // With it ignored here, the individual tests will run outside (i.e., separately/independently) of the suite in both "mvn test" and IntelliJ
@Disabled @Disabled
@ -84,7 +82,7 @@ public class OpValidationSuite {
Nd4j.getRandom().setSeed(123); Nd4j.getRandom().setSeed(123);
} }
@AfterClass @AfterEach
public static void afterClass() { public static void afterClass() {
Nd4j.setDataType(initialType); Nd4j.setDataType(initialType);

View File

@ -145,9 +145,8 @@ public class TestOpMapping extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOpMappingCoverage() throws Exception { public void testOpMappingCoverage() throws Exception {
Map<String, DifferentialFunction> opNameMapping = ImportClassMapping.getOpNameMapping(); Map<String, DifferentialFunction> opNameMapping = ImportClassMapping.getOpNameMapping();
Map<String, DifferentialFunction> tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions(); Map<String, DifferentialFunction> tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions();
@ -197,9 +196,8 @@ public class TestOpMapping extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOpsInNamespace(Nd4jBackend backend) throws Exception { public void testOpsInNamespace(Nd4jBackend backend) throws Exception {
//Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't //Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't
// want to add to a namespace for some reason) // want to add to a namespace for some reason)
@ -361,7 +359,7 @@ public class TestOpMapping extends BaseNd4jTestWithBackends {
@Test @Test
@Disabled @Disabled
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void generateOpClassList(Nd4jBackend backend) throws Exception{ public void generateOpClassList(Nd4jBackend backend) throws Exception{
Reflections reflections = new Reflections("org.nd4j"); Reflections reflections = new Reflections("org.nd4j");
Set<Class<? extends DifferentialFunction>> subTypes = reflections.getSubTypesOf(DifferentialFunction.class); Set<Class<? extends DifferentialFunction>> subTypes = reflections.getSubTypesOf(DifferentialFunction.class);

View File

@ -55,9 +55,8 @@ public class TestSessions extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInferenceSessionBasic(Nd4jBackend backend) { public void testInferenceSessionBasic(Nd4jBackend backend) {
//So far: trivial test to check execution order //So far: trivial test to check execution order
@ -89,9 +88,8 @@ public class TestSessions extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInferenceSessionBasic2(Nd4jBackend backend) { public void testInferenceSessionBasic2(Nd4jBackend backend) {
//So far: trivial test to check execution order //So far: trivial test to check execution order
@ -127,9 +125,8 @@ public class TestSessions extends BaseNd4jTestWithBackends {
assertEquals(dExp, outMap.get("d")); assertEquals(dExp, outMap.get("d"));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMergeSimple(Nd4jBackend backend) { public void testMergeSimple(Nd4jBackend backend) {
//This isn't really a sensible graph, as merge op behaviour is undefined when multiple inputs are available... //This isn't really a sensible graph, as merge op behaviour is undefined when multiple inputs are available...
@ -165,9 +162,8 @@ public class TestSessions extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSwitchSimple(Nd4jBackend backend) { public void testSwitchSimple(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();

View File

@ -34,7 +34,6 @@ import org.nd4j.common.primitives.Pair;
import java.util.Collections; import java.util.Collections;
import static junit.framework.TestCase.assertNotNull;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class TestDependencyTracker extends BaseNd4jTestWithBackends { public class TestDependencyTracker extends BaseNd4jTestWithBackends {
@ -45,9 +44,8 @@ public class TestDependencyTracker extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSimple(Nd4jBackend backend){ public void testSimple(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>(); DependencyTracker<String,String> dt = new DependencyTracker<>();
@ -94,9 +92,8 @@ public class TestDependencyTracker extends BaseNd4jTestWithBackends {
assertTrue(dt.isEmpty()); assertTrue(dt.isEmpty());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSatisfiedBeforeAdd(Nd4jBackend backend){ public void testSatisfiedBeforeAdd(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>(); DependencyTracker<String,String> dt = new DependencyTracker<>();
@ -135,9 +132,8 @@ public class TestDependencyTracker extends BaseNd4jTestWithBackends {
assertFalse(dt.hasNewAllSatisfied()); assertFalse(dt.hasNewAllSatisfied());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMarkUnsatisfied(Nd4jBackend backend){ public void testMarkUnsatisfied(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>(); DependencyTracker<String,String> dt = new DependencyTracker<>();
@ -169,9 +165,8 @@ public class TestDependencyTracker extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIdentityDependencyTracker(){ public void testIdentityDependencyTracker(){
IdentityDependencyTracker<INDArray, String> dt = new IdentityDependencyTracker<>(); IdentityDependencyTracker<INDArray, String> dt = new IdentityDependencyTracker<>();
assertTrue(dt.isEmpty()); assertTrue(dt.isEmpty());

View File

@ -41,9 +41,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class ActivationGradChecks extends BaseOpValidation { public class ActivationGradChecks extends BaseOpValidation {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testActivationGradientCheck1(Nd4jBackend backend) { public void testActivationGradientCheck1(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -61,9 +60,8 @@ public class ActivationGradChecks extends BaseOpValidation {
assertTrue(ok); assertTrue(ok);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testActivationGradientCheck2(Nd4jBackend backend) { public void testActivationGradientCheck2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();

View File

@ -73,9 +73,8 @@ public class LayerOpValidation extends BaseOpValidation {
return 90000L; return 90000L;
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testXwPlusB(Nd4jBackend backend) { public void testXwPlusB(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -109,9 +108,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReluLayer(Nd4jBackend backend) { public void testReluLayer(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -139,9 +137,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBiasAdd(Nd4jBackend backend) { public void testBiasAdd(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -165,9 +162,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConv2d(Nd4jBackend backend) { public void testConv2d(Nd4jBackend backend) {
//avg pool, batch norm, conv2d, max pool 2d, pooling2d, upsampling //avg pool, batch norm, conv2d, max pool 2d, pooling2d, upsampling
//Tested elsewhere: deconv2d, depthwise2d, LRN, sconv2d //Tested elsewhere: deconv2d, depthwise2d, LRN, sconv2d
@ -307,9 +303,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),failed.toString()); assertEquals(0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLrn2d(Nd4jBackend backend) { public void testLrn2d(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -350,9 +345,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),failed.toString()); assertEquals(0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIm2Col(Nd4jBackend backend) { public void testIm2Col(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873 //OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -391,9 +385,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOutputShape(Nd4jBackend backend) { public void testOutputShape(Nd4jBackend backend) {
long[] inSize = {1, 8, 8, 3}; long[] inSize = {1, 8, 8, 3};
@ -443,9 +436,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAvgPool(Nd4jBackend backend) { public void testAvgPool(Nd4jBackend backend) {
long[] inSize = {1, 8, 8, 3}; //NHWC long[] inSize = {1, 8, 8, 3}; //NHWC
@ -488,9 +480,8 @@ public class LayerOpValidation extends BaseOpValidation {
return new int[]{in[0], in[2], in[3], in[4], in[1]}; return new int[]{in[0], in[2], in[3], in[4], in[1]};
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConv3d(Nd4jBackend backend) { public void testConv3d(Nd4jBackend backend) {
//Pooling3d, Conv3D, batch norm //Pooling3d, Conv3D, batch norm
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -592,9 +583,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDepthWiseConv2dBasic(Nd4jBackend backend) { public void testDepthWiseConv2dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int depthWise = 4; int depthWise = 4;
@ -633,9 +623,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertArrayEquals(new long[]{mb, depthWise * nIn, 27, 27}, outShape); assertArrayEquals(new long[]{mb, depthWise * nIn, 27, 27}, outShape);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSeparableConv2dBasic(Nd4jBackend backend) { public void testSeparableConv2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 2; int nIn = 2;
@ -691,9 +680,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDeconv2dBasic(Nd4jBackend backend) { public void testDeconv2dBasic(Nd4jBackend backend) {
int nIn = 2; int nIn = 2;
int nOut = 3; int nOut = 3;
@ -737,9 +725,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConv2dBasic(Nd4jBackend backend) { public void testConv2dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
@ -780,9 +767,8 @@ public class LayerOpValidation extends BaseOpValidation {
// sd.execBackwards(); // TODO: test failing here // sd.execBackwards(); // TODO: test failing here
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMaxPoolingArgMax(Nd4jBackend backend) { public void testMaxPoolingArgMax(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 3; int nIn = 3;
@ -811,9 +797,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertArrayEquals(inArr.shape(), results[1].eval().shape()); assertArrayEquals(inArr.shape(), results[1].eval().shape());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMaxPooling2dBasic(Nd4jBackend backend) { public void testMaxPooling2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 3; int nIn = 3;
@ -871,9 +856,8 @@ public class LayerOpValidation extends BaseOpValidation {
return max; return max;
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAvgPooling2dBasic(Nd4jBackend backend) { public void testAvgPooling2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 3; int nIn = 3;
@ -922,9 +906,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAvgPooling3dBasic(Nd4jBackend backend) { public void testAvgPooling3dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int kH = 2; int kH = 2;
@ -961,9 +944,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMaxPooling3dBasic(Nd4jBackend backend) { public void testMaxPooling3dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int kH = 2; int kH = 2;
@ -1001,9 +983,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConv1dBasic(Nd4jBackend backend) { public void testConv1dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
@ -1038,9 +1019,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConv1dCausal(Nd4jBackend backend) { public void testConv1dCausal(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 3; int nIn = 3;
@ -1089,9 +1069,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConv1dForward(Nd4jBackend backend) { public void testConv1dForward(Nd4jBackend backend) {
int nIn = 2; int nIn = 2;
int nOut = 1; int nOut = 1;
@ -1134,9 +1113,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConv3dBasic(Nd4jBackend backend) { public void testConv3dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
@ -1182,9 +1160,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDeConv3dBasic(Nd4jBackend backend) { public void testDeConv3dBasic(Nd4jBackend backend) {
int nIn = 4; int nIn = 4;
int nOut = 3; int nOut = 3;
@ -1229,9 +1206,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLayerNorm(Nd4jBackend backend) { public void testLayerNorm(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
@ -1256,9 +1232,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLayerNorm4d(Nd4jBackend backend) { public void testLayerNorm4d(Nd4jBackend backend) {
int mb = 3; int mb = 3;
int ch = 4; int ch = 4;
@ -1290,9 +1265,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLayerNormOP(Nd4jBackend backend) { public void testLayerNormOP(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
@ -1308,9 +1282,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertEquals(res, output); assertEquals(res, output);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLayerNormNoBias(Nd4jBackend backend) { public void testLayerNormNoBias(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
@ -1333,9 +1306,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err, err); assertNull(err, err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLayerNormOPNoBias(Nd4jBackend backend) { public void testLayerNormOPNoBias(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
@ -1350,9 +1322,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertEquals(res, output); assertEquals(res, output);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLayerNormNoDeviation(Nd4jBackend backend) { public void testLayerNormNoDeviation(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
@ -1467,9 +1438,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLayerNormMixedOrders(Nd4jBackend backend) { public void testLayerNormMixedOrders(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f'); INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');
@ -1516,9 +1486,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertEquals(outCC, outFC); //Fails here assertEquals(outCC, outFC); //Fails here
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBiasAdd_nchw_nhwc(Nd4jBackend backend) { public void testBiasAdd_nchw_nhwc(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1549,9 +1518,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDepthwiseConv2D(){ public void testDepthwiseConv2D(){
int bS = 10; int bS = 10;
@ -1589,9 +1557,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void LSTMLayerTestCase1(Nd4jBackend backend) { public void LSTMLayerTestCase1(Nd4jBackend backend) {
int bS = 5; int bS = 5;
@ -1666,9 +1633,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void LSTMLayerTestCase2(Nd4jBackend backend) { public void LSTMLayerTestCase2(Nd4jBackend backend) {
int bS = 5; int bS = 5;
int nIn = 3; int nIn = 3;
@ -1726,9 +1692,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void LSTMLayerTestCase3(Nd4jBackend backend) { public void LSTMLayerTestCase3(Nd4jBackend backend) {
int bS = 5; int bS = 5;
int nIn = 3; int nIn = 3;
@ -1789,9 +1754,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void GRUTestCase(Nd4jBackend backend) { public void GRUTestCase(Nd4jBackend backend) {
int bS = 5; int bS = 5;
int nIn = 4; int nIn = 4;

View File

@ -55,9 +55,8 @@ public class LossOpValidation extends BaseOpValidation {
// All tested Loss Ops have backprop at the moment 2019/01/30 // All tested Loss Ops have backprop at the moment 2019/01/30
public static final Set<String> NO_BP_YET = new HashSet<>(); public static final Set<String> NO_BP_YET = new HashSet<>();
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLoss2d(Nd4jBackend backend) { public void testLoss2d(Nd4jBackend backend) {
final List<String> oneDimensionalOutputFns = Arrays.asList("cosine", "mpwse", "softmaxxent", "softmaxxent_smooth", "mpwse", "sparsesoftmax"); final List<String> oneDimensionalOutputFns = Arrays.asList("cosine", "mpwse", "softmaxxent", "softmaxxent_smooth", "mpwse", "sparsesoftmax");
@ -369,9 +368,8 @@ public class LossOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCosineDistance(){ public void testCosineDistance(){
INDArray arr = Nd4j.create(new double[][]{{-0.3, -0.2, -0.1}, {0, 0.1, 0.2}}); INDArray arr = Nd4j.create(new double[][]{{-0.3, -0.2, -0.1}, {0, 0.1, 0.2}});
INDArray label = Nd4j.create(new double[][]{{1.0, 2.0, 3.0}, {-1.0, 2.0, 1.0}}); INDArray label = Nd4j.create(new double[][]{{1.0, 2.0, 3.0}, {-1.0, 2.0, 1.0}});
@ -389,9 +387,8 @@ public class LossOpValidation extends BaseOpValidation {
assertEquals(exp, out); assertEquals(exp, out);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testL2Loss(){ public void testL2Loss(){
for( int rank=0; rank<=3; rank++ ){ for( int rank=0; rank<=3; rank++ ){
@ -433,9 +430,8 @@ public class LossOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNonZeroResult(Nd4jBackend backend) { public void testNonZeroResult(Nd4jBackend backend) {
INDArray predictions = Nd4j.rand(DataType.DOUBLE, 10, 5); INDArray predictions = Nd4j.rand(DataType.DOUBLE, 10, 5);
INDArray w = Nd4j.scalar(1.0); INDArray w = Nd4j.scalar(1.0);
@ -493,9 +489,8 @@ public class LossOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void TestStdLossMixedDataType(){ public void TestStdLossMixedDataType(){
// Default Data Type in this test suite is Double. // Default Data Type in this test suite is Double.
// This test used to throw an Exception that we have mixed data types. // This test used to throw an Exception that we have mixed data types.

View File

@ -74,17 +74,17 @@ import org.nd4j.common.util.ArrayUtil;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.Assume.assumeNotNull; import static org.junit.jupiter.api.Assumptions.*;
@Slf4j @Slf4j
public class MiscOpValidation extends BaseOpValidation { public class MiscOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGradientAutoBroadcast1(Nd4jBackend backend) { public void testGradientAutoBroadcast1(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -171,9 +171,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),"Failed: " + failed); assertEquals(0, failed.size(),"Failed: " + failed);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGradientAutoBroadcast2(Nd4jBackend backend) { public void testGradientAutoBroadcast2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -262,9 +261,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),"Failed: " + failed); assertEquals(0, failed.size(),"Failed: " + failed);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGradientAutoBroadcast3(Nd4jBackend backend) { public void testGradientAutoBroadcast3(Nd4jBackend backend) {
//These tests: output size > input sizes //These tests: output size > input sizes
@ -372,9 +370,8 @@ public class MiscOpValidation extends BaseOpValidation {
return Long.MAX_VALUE; return Long.MAX_VALUE;
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScatterOpGradients(Nd4jBackend backend) { public void testScatterOpGradients(Nd4jBackend backend) {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -476,9 +473,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),failed.toString()); assertEquals(0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScatterUpdate(){ public void testScatterUpdate(){
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3); INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3);
INDArray updates = Nd4j.create(new float[][]{ INDArray updates = Nd4j.create(new float[][]{
@ -499,9 +495,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(exp, out); assertEquals(exp, out);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGatherGradient(Nd4jBackend backend) { public void testGatherGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -552,9 +547,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTrace(){ public void testTrace(){
//TODO need to work out how to handle shape_op for scalars... //TODO need to work out how to handle shape_op for scalars...
//OpValidationSuite.ignoreFailing(); //OpValidationSuite.ignoreFailing();
@ -579,9 +573,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTensorGradTensorMmul(Nd4jBackend backend) { public void testTensorGradTensorMmul(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
@ -603,9 +596,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMulGradient(Nd4jBackend backend) { public void testMulGradient(Nd4jBackend backend) {
INDArray arr1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray arr1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray arr2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray arr2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
@ -670,9 +662,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMmulGradientManual(Nd4jBackend backend) { public void testMmulGradientManual(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
@ -689,14 +680,14 @@ public class MiscOpValidation extends BaseOpValidation {
}, inputs); }, inputs);
assumeNotNull(sameDiff.getFunction("mmulGradient").getFunction("grad")); assertNotNull(sameDiff.getFunction("mmulGradient").getFunction("grad"));
assumeNotNull(sameDiff.getFunction("mmulGradient").grad("x")); assertNotNull(sameDiff.getFunction("mmulGradient").grad("x"));
assumeNotNull(sameDiff.getFunction("mmulGradient").grad("y")); assertNotNull(sameDiff.getFunction("mmulGradient").grad("y"));
SDVariable gradWrtX = sameDiff.getFunction("mmulGradient").grad("x"); SDVariable gradWrtX = sameDiff.getFunction("mmulGradient").grad("x");
SDVariable gradWrtY = sameDiff.getFunction("mmulGradient").grad("y"); SDVariable gradWrtY = sameDiff.getFunction("mmulGradient").grad("y");
assumeNotNull(gradWrtX.getArr()); assertNotNull(gradWrtX.getArr());
assumeNotNull(gradWrtY.getArr()); assertNotNull(gradWrtY.getArr());
INDArray xGradAssertion = Nd4j.create(new double[][]{ INDArray xGradAssertion = Nd4j.create(new double[][]{
@ -713,9 +704,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(yGradAssertion, gradWrtY.getArr()); assertEquals(yGradAssertion, gradWrtY.getArr());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMmulGradients(){ public void testMmulGradients(){
int[] aShape = new int[]{2,3}; int[] aShape = new int[]{2,3};
int[] bShape = new int[]{3,4}; int[] bShape = new int[]{3,4};
@ -766,9 +756,8 @@ public class MiscOpValidation extends BaseOpValidation {
return new int[]{orig[1], orig[0]}; return new int[]{orig[1], orig[0]};
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBatchMmulBasic(Nd4jBackend backend) { public void testBatchMmulBasic(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6873 OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6873
int M = 5; int M = 5;
@ -793,9 +782,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMmulWithTranspose(Nd4jBackend backend) { public void testMmulWithTranspose(Nd4jBackend backend) {
//Here: [x,3]^T * [x,4] = [3,4] //Here: [x,3]^T * [x,4] = [3,4]
@ -832,9 +820,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMmulOutputSizeCalculation(){ public void testMmulOutputSizeCalculation(){
//[3,2] x [2,4] with result transpose: output shape [4,3] //[3,2] x [2,4] with result transpose: output shape [4,3]
INDArray a = Nd4j.create(3,2); INDArray a = Nd4j.create(3,2);
@ -866,9 +853,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testFillOp(){ public void testFillOp(){
INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT); INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT);
@ -882,9 +868,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testClipByNorm(){ public void testClipByNorm(){
//Expected: if array.norm2(1) is less than 1.0, not modified //Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -916,9 +901,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(exp, norm2_1b); assertEquals(exp, norm2_1b);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testClipByNorm2(){ public void testClipByNorm2(){
//Expected: if array.norm2(1) is less than 1.0, not modified //Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -961,9 +945,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testClipByNorm1(){ public void testClipByNorm1(){
//Expected: if array.norm2(1) is less than 1.0, not modified //Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -1003,9 +986,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testClipByNorm0(){ public void testClipByNorm0(){
//Expected: if array.norm2(0) is less than 1.0, not modified //Expected: if array.norm2(0) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -1034,9 +1016,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(OpValidation.validate(op)); assertNull(OpValidation.validate(op));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCumSum(){ public void testCumSum(){
List<String> failing = new ArrayList<>(); List<String> failing = new ArrayList<>();
@ -1101,9 +1082,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCumProd(){ public void testCumProd(){
List<String> failing = new ArrayList<>(); List<String> failing = new ArrayList<>();
@ -1171,9 +1151,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(0, failing.size(),failing.toString()); assertEquals(0, failing.size(),failing.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOneHot1(){ public void testOneHot1(){
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -1203,9 +1182,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals( 0, failed.size(),failed.toString()); assertEquals( 0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOneHotOp(){ public void testOneHotOp(){
//https://www.tensorflow.org/api_docs/python/tf/one_hot //https://www.tensorflow.org/api_docs/python/tf/one_hot
//https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp //https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp
@ -1219,9 +1197,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOneHot2(Nd4jBackend backend) { public void testOneHot2(Nd4jBackend backend) {
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1); INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
@ -1241,9 +1218,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOneHot4(Nd4jBackend backend) { public void testOneHot4(Nd4jBackend backend) {
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1); INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
@ -1263,9 +1239,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOneHot3(Nd4jBackend backend) { public void testOneHot3(Nd4jBackend backend) {
//https://github.com/deeplearning4j/deeplearning4j/issues/6872 //https://github.com/deeplearning4j/deeplearning4j/issues/6872
@ -1300,9 +1275,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLinspace(){ public void testLinspace(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable out = sd.linspace("linspace", DataType.DOUBLE, 1,10,10); SDVariable out = sd.linspace("linspace", DataType.DOUBLE, 1,10,10);
@ -1315,9 +1289,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLinspace2(){ public void testLinspace2(){
OpValidationSuite.ignoreFailing(); //TODO 2019/01/18 OpValidationSuite.ignoreFailing(); //TODO 2019/01/18
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1331,9 +1304,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testShapeFn(Nd4jBackend backend) { public void testShapeFn(Nd4jBackend backend) {
INDArray in = Nd4j.create(new long[]{1, 2}); INDArray in = Nd4j.create(new long[]{1, 2});
@ -1347,9 +1319,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertArrayEquals(new long[]{2}, shapes.get(0).getShape()); assertArrayEquals(new long[]{2}, shapes.get(0).getShape());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testShapeFn2(Nd4jBackend backend) { public void testShapeFn2(Nd4jBackend backend) {
INDArray i = Nd4j.create(1,3); INDArray i = Nd4j.create(1,3);
@ -1362,9 +1333,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMergeRank1(){ public void testMergeRank1(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5)); SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5));
@ -1382,9 +1352,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(1, inGrad.rank()); assertEquals(1, inGrad.rank());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDiagPart(Nd4jBackend backend) { public void testDiagPart(Nd4jBackend backend) {
INDArray i = Nd4j.create(5,5); INDArray i = Nd4j.create(5,5);
@ -1396,9 +1365,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(1, out.rank()); assertEquals(1, out.rank());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDiagShapeFn(Nd4jBackend backend) { public void testDiagShapeFn(Nd4jBackend backend) {
INDArray i = Nd4j.create(5,5); INDArray i = Nd4j.create(5,5);
@ -1411,9 +1379,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testZerosOnesLike(){ public void testZerosOnesLike(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1455,9 +1422,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),failed.toString()); assertEquals(0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testZerosLikeOp(){ public void testZerosLikeOp(){
INDArray arr = Nd4j.scalar(DataType.DOUBLE, 1.0); INDArray arr = Nd4j.scalar(DataType.DOUBLE, 1.0);
@ -1472,9 +1438,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConfusionMatrix(){ public void testConfusionMatrix(){
DataType dt = DataType.DOUBLE; DataType dt = DataType.DOUBLE;
@ -1510,9 +1475,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIsNonDecreasingIsStrictlyIncr(){ public void testIsNonDecreasingIsStrictlyIncr(){
List<long[]> shapes = Arrays.asList(null, new long[]{12}, new long[]{1,12}, new long[]{3,4}, new long[]{2,2,3}); List<long[]> shapes = Arrays.asList(null, new long[]{12}, new long[]{1,12}, new long[]{3,4}, new long[]{2,2,3});
@ -1575,9 +1539,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals( 0, failed.size(),failed.toString()); assertEquals( 0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testExtractImagePatches(){ public void testExtractImagePatches(){
/* /*
tf.reset_default_graph() tf.reset_default_graph()
@ -1624,9 +1587,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(exp, out); assertEquals(exp, out);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSegmentProdBpSimple(){ public void testSegmentProdBpSimple(){
INDArray segmentIdxs = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT); INDArray segmentIdxs = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT);
@ -1646,9 +1608,8 @@ public class MiscOpValidation extends BaseOpValidation {
Nd4j.getExecutioner().exec(op); Nd4j.getExecutioner().exec(op);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMmulRank4() throws Exception { public void testMmulRank4() throws Exception {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1683,9 +1644,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(outExp, out); assertEquals(outExp, out);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMmulRank4_simple(){ public void testMmulRank4_simple(){
INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64); INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
@ -1711,9 +1671,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(exp, out); assertEquals(exp, out);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNthElementRank1(){ public void testNthElementRank1(){
INDArray in = Nd4j.createFromArray(new double[]{0,1,2,3,4,5,6,7,8,9}); INDArray in = Nd4j.createFromArray(new double[]{0,1,2,3,4,5,6,7,8,9});
INDArray n = Nd4j.scalar(0); INDArray n = Nd4j.scalar(0);
@ -1735,9 +1694,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(0.0, out.getDouble(0), 1e-5); assertEquals(0.0, out.getDouble(0), 1e-5);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTensorMmulShape(){ public void testTensorMmulShape(){
INDArray a = Nd4j.create(new double[]{2}).reshape(1); INDArray a = Nd4j.create(new double[]{2}).reshape(1);
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2); INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
@ -1755,9 +1713,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertArrayEquals(new long[]{2,2}, l.get(0).getShape()); //Returning [1,2,2] assertArrayEquals(new long[]{2,2}, l.get(0).getShape()); //Returning [1,2,2]
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTensorMmulShape2(){ public void testTensorMmulShape2(){
INDArray a = Nd4j.create(new double[]{2}).reshape(1); INDArray a = Nd4j.create(new double[]{2}).reshape(1);
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2); INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
@ -1765,9 +1722,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertArrayEquals(new long[]{2,2}, c.shape()); assertArrayEquals(new long[]{2,2}, c.shape());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStopGradient(){ public void testStopGradient(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1786,9 +1742,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(Nd4j.zeros(DataType.DOUBLE, 3, 4), wArr); assertEquals(Nd4j.zeros(DataType.DOUBLE, 3, 4), wArr);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckNumerics(){ public void testCheckNumerics(){
OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/7927 OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/7927
@ -1831,9 +1786,8 @@ public class MiscOpValidation extends BaseOpValidation {
sd.outputAll(Collections.singletonMap("in", in)); sd.outputAll(Collections.singletonMap("in", in));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckNumerics2(Nd4jBackend backend) { public void testCheckNumerics2(Nd4jBackend backend) {
INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4); INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4);
INDArray msg = Nd4j.scalar("My error message!"); INDArray msg = Nd4j.scalar("My error message!");
@ -1846,9 +1800,8 @@ public class MiscOpValidation extends BaseOpValidation {
Nd4j.getExecutioner().exec(op); Nd4j.getExecutioner().exec(op);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testHistogramFixedWidth(){ public void testHistogramFixedWidth(){
//Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf] //Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf]
INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9); INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9);
@ -1866,9 +1819,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(exp, out); assertEquals(exp, out);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDynamicPartition(){ public void testDynamicPartition(){
INDArray data = Nd4j.createFromArray(2, 1, 2, 0); INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
@ -1886,9 +1838,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(exp2, out[2]); assertEquals(exp2, out[2]);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testListDiff(){ public void testListDiff(){
INDArray x = Nd4j.createFromArray(0, 1, 2, 3); INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
INDArray y = Nd4j.createFromArray(3, 1); INDArray y = Nd4j.createFromArray(3, 1);
@ -1907,9 +1858,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(exp, outIdx); //Indices of the values in x not in y assertEquals(exp, outIdx); //Indices of the values in x not in y
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDivideNoNan(Nd4jBackend backend) { public void testDivideNoNan(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff() OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff()
@ -1933,9 +1883,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDigamma(Nd4jBackend backend) { public void testDigamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -1950,9 +1899,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testFlatten(Nd4jBackend backend) { public void testFlatten(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1974,9 +1922,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testFusedBatchNorm(Nd4jBackend backend) { public void testFusedBatchNorm(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2021,9 +1968,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIgamma(Nd4jBackend backend) { public void testIgamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -2039,9 +1985,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIgammaC(Nd4jBackend backend) { public void testIgammaC(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -2058,9 +2003,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLgamma(Nd4jBackend backend) { public void testLgamma(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2085,9 +2029,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLu(Nd4jBackend backend) { public void testLu(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2118,9 +2061,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMatrixBandPart(Nd4jBackend backend) { public void testMatrixBandPart(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2150,9 +2092,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPolygamma(Nd4jBackend backend) { public void testPolygamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -2168,9 +2109,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTriangularSolve(Nd4jBackend backend) { public void testTriangularSolve(Nd4jBackend backend) {
INDArray a = Nd4j.createFromArray(new float[]{ INDArray a = Nd4j.createFromArray(new float[]{
@ -2194,9 +2134,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBiasAdd(Nd4jBackend backend) { public void testBiasAdd(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2225,9 +2164,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBiasAddGrad(Nd4jBackend backend) { public void testBiasAddGrad(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2247,9 +2185,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRoll(Nd4jBackend backend) { public void testRoll(Nd4jBackend backend) {
INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
@ -2269,9 +2206,8 @@ public class MiscOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSeqMask(){ public void testSeqMask(){
INDArray arr = Nd4j.createFromArray(1,2,3); INDArray arr = Nd4j.createFromArray(1,2,3);
INDArray maxLen = Nd4j.scalar(4); INDArray maxLen = Nd4j.scalar(4);

View File

@ -54,9 +54,8 @@ import static org.junit.jupiter.api.Assertions.*;
public class RandomOpValidation extends BaseOpValidation { public class RandomOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRandomOpsSDVarShape(Nd4jBackend backend) { public void testRandomOpsSDVarShape(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -157,9 +156,8 @@ public class RandomOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),failed.toString()); assertEquals(0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRandomOpsLongShape(Nd4jBackend backend) { public void testRandomOpsLongShape(Nd4jBackend backend) {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -285,9 +283,8 @@ public class RandomOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),failed.toString()); assertEquals(0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRandomBinomial(){ public void testRandomBinomial(){
INDArray z = Nd4j.create(new long[]{10}); INDArray z = Nd4j.create(new long[]{10});
@ -297,9 +294,8 @@ public class RandomOpValidation extends BaseOpValidation {
System.out.println(z); System.out.println(z);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUniformRankSimple(Nd4jBackend backend) { public void testUniformRankSimple(Nd4jBackend backend) {
INDArray arr = Nd4j.createFromArray(new double[]{100.0}); INDArray arr = Nd4j.createFromArray(new double[]{100.0});
@ -331,9 +327,8 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRandomExponential(Nd4jBackend backend) { public void testRandomExponential(Nd4jBackend backend) {
long length = 1_000_000; long length = 1_000_000;
INDArray shape = Nd4j.createFromArray(new double[]{length}); INDArray shape = Nd4j.createFromArray(new double[]{length});
@ -355,9 +350,8 @@ public class RandomOpValidation extends BaseOpValidation {
assertEquals( expStd, std, 0.1,"std"); assertEquals( expStd, std, 0.1,"std");
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRange(){ public void testRange(){
//Technically deterministic, not random... //Technically deterministic, not random...
@ -390,9 +384,8 @@ public class RandomOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAllEmptyReduce(){ public void testAllEmptyReduce(){
INDArray x = Nd4j.createFromArray(true, true, true); INDArray x = Nd4j.createFromArray(true, true, true);
All all = new All(x); All all = new All(x);
@ -401,9 +394,8 @@ public class RandomOpValidation extends BaseOpValidation {
assertEquals(x, out); assertEquals(x, out);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUniformDtype(){ public void testUniformDtype(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){ for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
@ -431,9 +423,8 @@ public class RandomOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRandomExponential2(){ public void testRandomExponential2(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
DynamicCustomOp op = DynamicCustomOp.builder("random_exponential") DynamicCustomOp op = DynamicCustomOp.builder("random_exponential")

View File

@ -75,9 +75,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReduceSumBP(Nd4jBackend backend) { public void testReduceSumBP(Nd4jBackend backend) {
//Full array reduction //Full array reduction
@ -103,9 +102,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReduceSumAlongDim0BP(Nd4jBackend backend) { public void testReduceSumAlongDim0BP(Nd4jBackend backend) {
//Reduction along dimension //Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar //Inputs/outputs as before - but note that the output is no longer a scalar
@ -131,9 +129,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReduceSumAlongDim1BP(Nd4jBackend backend) { public void testReduceSumAlongDim1BP(Nd4jBackend backend) {
//Reduction along dimension //Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar //Inputs/outputs as before - but note that the output is no longer a scalar
@ -161,9 +158,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMeanBP(Nd4jBackend backend) { public void testMeanBP(Nd4jBackend backend) {
//dL/dIn_i = dL/dOut * dOut/dIn_i = dL/dOut * (1/N * sum_j (in_j)) //dL/dIn_i = dL/dOut * dOut/dIn_i = dL/dOut * (1/N * sum_j (in_j))
@ -194,9 +190,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMeanBP_Rank1(Nd4jBackend backend) { public void testMeanBP_Rank1(Nd4jBackend backend) {
INDArray dLdOut = Nd4j.scalar(0.5); INDArray dLdOut = Nd4j.scalar(0.5);
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3}); INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
@ -209,9 +204,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMeanAlongDim0BP(Nd4jBackend backend) { public void testMeanAlongDim0BP(Nd4jBackend backend) {
//Reduction along dimension //Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar //Inputs/outputs as before - but note that the output is no longer a scalar
@ -239,9 +233,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMeanAlongDim1BP(Nd4jBackend backend) { public void testMeanAlongDim1BP(Nd4jBackend backend) {
//Reduction along dimension //Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar //Inputs/outputs as before - but note that the output is no longer a scalar
@ -269,9 +262,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMinBP(Nd4jBackend backend) { public void testMinBP(Nd4jBackend backend) {
//Full array min reduction //Full array min reduction
@ -310,9 +302,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMinAlongDimensionBP(Nd4jBackend backend) { public void testMinAlongDimensionBP(Nd4jBackend backend) {
//Full array min reduction //Full array min reduction
@ -355,9 +346,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMaxBP(Nd4jBackend backend) { public void testMaxBP(Nd4jBackend backend) {
//Full array max reduction //Full array max reduction
@ -387,9 +377,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMaxAlongDimensionBP(Nd4jBackend backend) { public void testMaxAlongDimensionBP(Nd4jBackend backend) {
//Full array min reduction //Full array min reduction
@ -432,9 +421,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testProdBP(Nd4jBackend backend) { public void testProdBP(Nd4jBackend backend) {
//Full array product reduction //Full array product reduction
@ -463,9 +451,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testProdAlongDimensionBP(Nd4jBackend backend) { public void testProdAlongDimensionBP(Nd4jBackend backend) {
//dL/dIn_i = dL/dOut * dOut/dIn_i //dL/dIn_i = dL/dOut * dOut/dIn_i
// = dL/dOut * d(prod(in))/dIn_i // = dL/dOut * d(prod(in))/dIn_i
@ -521,9 +508,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStdevBP(Nd4jBackend backend) { public void testStdevBP(Nd4jBackend backend) {
//If out = stdev(in) then: //If out = stdev(in) then:
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
@ -559,9 +545,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStdevBP_Rank1(Nd4jBackend backend) { public void testStdevBP_Rank1(Nd4jBackend backend) {
INDArray dLdOut = Nd4j.scalar(0.5); INDArray dLdOut = Nd4j.scalar(0.5);
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3}); INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
@ -582,9 +567,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStdevAlongDimensionBP(Nd4jBackend backend) { public void testStdevAlongDimensionBP(Nd4jBackend backend) {
//If out = stdev(in) then: //If out = stdev(in) then:
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
@ -629,9 +613,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testVarianceBP(Nd4jBackend backend) { public void testVarianceBP(Nd4jBackend backend) {
//If out = variance(in) then: //If out = variance(in) then:
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
@ -667,9 +650,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testVarianceAlongDimensionBP(Nd4jBackend backend) { public void testVarianceAlongDimensionBP(Nd4jBackend backend) {
//If out = variance(in) then: //If out = variance(in) then:
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
@ -711,9 +693,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCumSumBP(Nd4jBackend backend) { public void testCumSumBP(Nd4jBackend backend) {
//Standard case, non-reverse, non-exclusive //Standard case, non-reverse, non-exclusive
//dL/dIn_i = sum_j dL/dOut_j * dOut_j/dIn_i //dL/dIn_i = sum_j dL/dOut_j * dOut_j/dIn_i
@ -783,9 +764,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNorm2Bp(Nd4jBackend backend) { public void testNorm2Bp(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * x/|x|_2 // = dL/dOut * x/|x|_2
@ -812,9 +792,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNorm2AlongDimensionBP(Nd4jBackend backend) { public void testNorm2AlongDimensionBP(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * x/|x|_2 // = dL/dOut * x/|x|_2
@ -847,9 +826,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNorm1Bp(Nd4jBackend backend) { public void testNorm1Bp(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * sgn(in) // = dL/dOut * sgn(in)
@ -876,9 +854,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNorm1AlongDimensionBP(Nd4jBackend backend) { public void testNorm1AlongDimensionBP(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * sgn(in) // = dL/dOut * sgn(in)
@ -910,9 +887,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNormMaxBp(Nd4jBackend backend) { public void testNormMaxBp(Nd4jBackend backend) {
//out = max_i (|in_i|) //out = max_i (|in_i|)
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
@ -942,9 +918,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNormMaxAlongDimensionBP(Nd4jBackend backend) { public void testNormMaxAlongDimensionBP(Nd4jBackend backend) {
//out = max_i (|in_i|) //out = max_i (|in_i|)
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn

View File

@ -80,9 +80,8 @@ import static org.junit.jupiter.api.Assertions.*;
public class ReductionOpValidation extends BaseOpValidation { public class ReductionOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStdev(Nd4jBackend backend) { public void testStdev(Nd4jBackend backend) {
List<String> errors = new ArrayList<>(); List<String> errors = new ArrayList<>();
@ -108,9 +107,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertEquals(0, errors.size(),errors.toString()); assertEquals(0, errors.size(),errors.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testZeroCount(Nd4jBackend backend) { public void testZeroCount(Nd4jBackend backend) {
List<String> allFailed = new ArrayList<>(); List<String> allFailed = new ArrayList<>();
for (int i = 0; i < 21; i++) { for (int i = 0; i < 21; i++) {
@ -144,9 +142,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testZeroFraction(Nd4jBackend backend) { public void testZeroFraction(Nd4jBackend backend) {
List<String> allFailed = new ArrayList<>(); List<String> allFailed = new ArrayList<>();
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
@ -176,9 +173,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertEquals(0, allFailed.size(),allFailed.toString()); assertEquals(0, allFailed.size(),allFailed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReductionGradientsSimple(Nd4jBackend backend) { public void testReductionGradientsSimple(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES //OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES
//Test reductions: final and only function //Test reductions: final and only function
@ -347,9 +343,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),failed.toString()); assertEquals(0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReductionGradients1(Nd4jBackend backend) { public void testReductionGradients1(Nd4jBackend backend) {
//Test reductions: final, but *not* the only function //Test reductions: final, but *not* the only function
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -477,9 +472,8 @@ public class ReductionOpValidation extends BaseOpValidation {
return Long.MAX_VALUE; return Long.MAX_VALUE;
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReductionGradients2(Nd4jBackend backend) { public void testReductionGradients2(Nd4jBackend backend) {
//Test reductions: NON-final function //Test reductions: NON-final function
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -657,9 +651,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReduce3(Nd4jBackend backend) { public void testReduce3(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -764,9 +757,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),"Failed: " + failed); assertEquals(0, failed.size(),"Failed: " + failed);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMoments(Nd4jBackend backend) { public void testMoments(Nd4jBackend backend) {
for (int[] axes : new int[][]{{0}, {1}, {0, 1}}) { for (int[] axes : new int[][]{{0}, {1}, {0, 1}}) {
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -798,9 +790,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMomentsOp(Nd4jBackend backend) { public void testMomentsOp(Nd4jBackend backend) {
int[] axes = new int[]{0}; int[] axes = new int[]{0};
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -817,9 +808,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNormalizeMomentsOp(Nd4jBackend backend) { public void testNormalizeMomentsOp(Nd4jBackend backend) {
INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10); INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10);
INDArray ssSum = data.sum(0); INDArray ssSum = data.sum(0);
@ -839,9 +829,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAllAny(Nd4jBackend backend) { public void testAllAny(Nd4jBackend backend) {
INDArray allZeros = Nd4j.zeros(DataType.FLOAT, 3, 4); INDArray allZeros = Nd4j.zeros(DataType.FLOAT, 3, 4);
@ -869,9 +858,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIndexAccum(Nd4jBackend backend) { public void testIndexAccum(Nd4jBackend backend) {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/); List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/);
@ -960,9 +948,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReduce3_2(Nd4jBackend backend) { public void testReduce3_2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1060,9 +1047,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReductionsBackwards(Nd4jBackend backend) { public void testReductionsBackwards(Nd4jBackend backend) {
// for (int i = 0; i < 7; i++) { // for (int i = 0; i < 7; i++) {
int i=5; int i=5;
@ -1131,9 +1117,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDotProductAttention(){ public void testDotProductAttention(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3}); final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1158,9 +1143,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDotProductAttentionWithMask(){ public void testDotProductAttentionWithMask(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3}); final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1190,9 +1174,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDotProductAttentionMultiHeadInputWithMask(){ public void testDotProductAttentionMultiHeadInputWithMask(){
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
@ -1223,9 +1206,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDotProductAttentionMultiHeadInput(){ public void testDotProductAttentionMultiHeadInput(){
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
@ -1252,9 +1234,8 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMultiHeadedDotProductAttention(){ public void testMultiHeadedDotProductAttention(){
final INDArray k = Nd4j.rand(new int[]{10, 4, 5}); final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
final INDArray v = Nd4j.rand(new int[]{10, 4, 5}); final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
@ -1305,9 +1286,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDotProductAttentionWeirdInputs(){ public void testDotProductAttentionWeirdInputs(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3}); final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1344,9 +1324,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMultiHeadedDotProductAttentionWeirdInputs(){ public void testMultiHeadedDotProductAttentionWeirdInputs(){
final INDArray k = Nd4j.rand(new int[]{10, 4, 5}); final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
final INDArray v = Nd4j.rand(new int[]{10, 4, 5}); final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
@ -1403,9 +1382,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSufficientStatisticsOp(Nd4jBackend backend) { public void testSufficientStatisticsOp(Nd4jBackend backend) {
INDArray data = Nd4j.createFromArray(new double[]{ INDArray data = Nd4j.createFromArray(new double[]{
5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1., 5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1.,
@ -1431,9 +1409,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStandardDeviation(Nd4jBackend backend) { public void testStandardDeviation(Nd4jBackend backend) {
for (boolean keepDims : new boolean[]{false, true}) { for (boolean keepDims : new boolean[]{false, true}) {
@ -1460,9 +1437,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSquaredNorm(Nd4jBackend backend) { public void testSquaredNorm(Nd4jBackend backend) {
for (boolean keepDims : new boolean[]{false, true}) { for (boolean keepDims : new boolean[]{false, true}) {
@ -1485,9 +1461,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testShannonEntropy(Nd4jBackend backend) { public void testShannonEntropy(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695 OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695
@ -1507,9 +1482,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEntropy(Nd4jBackend backend) { public void testEntropy(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1528,9 +1502,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAMean(Nd4jBackend backend) { public void testAMean(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1551,9 +1524,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMean(Nd4jBackend backend) { public void testMean(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1574,9 +1546,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNorm1(Nd4jBackend backend) { public void testNorm1(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1597,9 +1568,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNorm2(Nd4jBackend backend) { public void testNorm2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1620,9 +1590,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNormMax(Nd4jBackend backend) { public void testNormMax(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1643,9 +1612,8 @@ public class ReductionOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSoftmaxCrossEntropyWithLogitsLoss(Nd4jBackend backend) { public void testSoftmaxCrossEntropyWithLogitsLoss(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();

View File

@ -46,9 +46,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j @Slf4j
public class RnnOpValidation extends BaseOpValidation { public class RnnOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRnnBlockCell(Nd4jBackend backend) { public void testRnnBlockCell(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int mb = 2; int mb = 2;
@ -147,9 +146,8 @@ public class RnnOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRnnBlockCellManualTFCompare(Nd4jBackend backend) { public void testRnnBlockCellManualTFCompare(Nd4jBackend backend) {
//Test case: "rnn/lstmblockcell/static_batch1_n3-2_tsLength1_noPH_noClip_fBias1_noIS" //Test case: "rnn/lstmblockcell/static_batch1_n3-2_tsLength1_noPH_noClip_fBias1_noIS"
@ -211,9 +209,8 @@ public class RnnOpValidation extends BaseOpValidation {
assertEquals(out6, m.get(toExec.get(6))); //Output assertEquals(out6, m.get(toExec.get(6))); //Output
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGRUCell(){ public void testGRUCell(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int mb = 2; int mb = 2;

View File

@ -81,9 +81,8 @@ public class ShapeOpValidation extends BaseOpValidation {
doRepeat doRepeat
*/ */
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConcat(Nd4jBackend backend) { public void testConcat(Nd4jBackend backend) {
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2}; // int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
int[] concatDim = new int[]{0, 0, 0}; int[] concatDim = new int[]{0, 0, 0};
@ -123,9 +122,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals( 0, failed.size(),failed.toString()); assertEquals( 0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReshapeGradient(Nd4jBackend backend) { public void testReshapeGradient(Nd4jBackend backend) {
//https://github.com/deeplearning4j/deeplearning4j/issues/6873 //https://github.com/deeplearning4j/deeplearning4j/issues/6873
@ -161,9 +159,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),failed.toString()); assertEquals(0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPermuteGradient(Nd4jBackend backend) { public void testPermuteGradient(Nd4jBackend backend) {
int[] origShape = new int[]{3, 4, 5}; int[] origShape = new int[]{3, 4, 5};
@ -201,9 +198,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),failed.toString()); assertEquals(0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRank(){ public void testRank(){
List<long[]> inShape = Arrays.asList(null, new long[]{1}, new long[]{6}, new long[]{3,4}, new long[]{3,4,5}); List<long[]> inShape = Arrays.asList(null, new long[]{1}, new long[]{6}, new long[]{3,4}, new long[]{3,4,5});
@ -230,9 +226,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testExpandDimsGradient(Nd4jBackend backend) { public void testExpandDimsGradient(Nd4jBackend backend) {
val origShape = new long[]{3, 4}; val origShape = new long[]{3, 4};
@ -288,9 +283,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),failed.toString()); assertEquals(0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSqueezeGradient(Nd4jBackend backend) { public void testSqueezeGradient(Nd4jBackend backend) {
val origShape = new long[]{3, 4, 5}; val origShape = new long[]{3, 4, 5};
@ -354,9 +348,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSliceGradient(Nd4jBackend backend) { public void testSliceGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -446,9 +439,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStridedSliceGradient(Nd4jBackend backend) { public void testStridedSliceGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -511,9 +503,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMerge(Nd4jBackend backend) { public void testMerge(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -680,9 +671,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUnStack(Nd4jBackend backend) { public void testUnStack(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -770,9 +760,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals( 0, failed.size(),failed.toString()); assertEquals( 0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTile(Nd4jBackend backend) { public void testTile(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -844,9 +833,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTileBp(){ public void testTileBp(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -879,9 +867,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTileBp2(){ public void testTileBp2(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -915,9 +902,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReshape(Nd4jBackend backend) { public void testReshape(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(-5, 6, 12)).reshape(3, 4); INDArray arr = Transforms.sigmoid(Nd4j.linspace(-5, 6, 12)).reshape(3, 4);
@ -933,9 +919,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReshape2(Nd4jBackend backend) { public void testReshape2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int[] origShape = new int[]{3, 4, 5}; int[] origShape = new int[]{3, 4, 5};
@ -958,9 +943,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTranspose(Nd4jBackend backend) { public void testTranspose(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
@ -972,9 +956,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTransposeOp(){ public void testTransposeOp(){
INDArray arr = Nd4j.linspace(1,15, 15).reshape(5,3); INDArray arr = Nd4j.linspace(1,15, 15).reshape(5,3);
@ -987,9 +970,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testShape(Nd4jBackend backend) { public void testShape(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
val shape = new long[]{2, 3}; val shape = new long[]{2, 3};
@ -1004,9 +986,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSize(Nd4jBackend backend) { public void testSize(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
val shape = new long[]{2, 3}; val shape = new long[]{2, 3};
@ -1020,9 +1001,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDiagShapeFn(Nd4jBackend backend) { public void testDiagShapeFn(Nd4jBackend backend) {
INDArray i = Nd4j.linspace(1, 16, 16).reshape(4,4); INDArray i = Nd4j.linspace(1, 16, 16).reshape(4,4);
@ -1036,9 +1016,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPermute(){ public void testPermute(){
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5); INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
INDArray exp = in.permute(0,1,2); //No op INDArray exp = in.permute(0,1,2); //No op
@ -1052,9 +1031,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertNull(OpValidation.validate(op)); assertNull(OpValidation.validate(op));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPermute2(){ public void testPermute2(){
for (int[] perm : new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}) { for (int[] perm : new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}) {
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5); INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
@ -1074,9 +1052,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConstant(){ public void testConstant(){
//OpValidationSuite.ignoreFailing(); //OpValidationSuite.ignoreFailing();
@ -1103,9 +1080,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUnstackEdgeCase2(){ public void testUnstackEdgeCase2(){
for( int i=0; i<3; i++ ) { for( int i=0; i<3; i++ ) {
@ -1119,9 +1095,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void invertPermutation(Nd4jBackend backend) { public void invertPermutation(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1138,9 +1113,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGatherNd(){ public void testGatherNd(){
List<INDArray> indices = new ArrayList<>(); List<INDArray> indices = new ArrayList<>();
@ -1178,9 +1152,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReverseSequence(Nd4jBackend backend) { public void testReverseSequence(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
float[] input_data = new float[]{ float[] input_data = new float[]{
@ -1226,9 +1199,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMatrixDeterminant(){ public void testMatrixDeterminant(){
OpValidationSuite.ignoreFailing(); //Gradient check failing OpValidationSuite.ignoreFailing(); //Gradient check failing
@ -1249,9 +1221,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDeterminant22(){ public void testDeterminant22(){
OpValidationSuite.ignoreFailing(); //Gradient check failing OpValidationSuite.ignoreFailing(); //Gradient check failing
@ -1275,9 +1246,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMatrixDeterminant3(){ public void testMatrixDeterminant3(){
OpValidationSuite.ignoreFailing(); //Gradient checks failing OpValidationSuite.ignoreFailing(); //Gradient checks failing
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1308,9 +1278,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMatrixDeterminant4(){ public void testMatrixDeterminant4(){
OpValidationSuite.ignoreFailing(); //Gradient checks failing OpValidationSuite.ignoreFailing(); //Gradient checks failing
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1330,9 +1299,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSegmentOps(){ public void testSegmentOps(){
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
//https://github.com/deeplearning4j/deeplearning4j/issues/6952 //https://github.com/deeplearning4j/deeplearning4j/issues/6952
@ -1424,9 +1392,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),failed.toString()); assertEquals(0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSegmentMean(){ public void testSegmentMean(){
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3); INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3);
INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2); INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2);
@ -1446,9 +1413,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(exp, out); assertEquals(exp, out);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSequenceMask(Nd4jBackend backend) { public void testSequenceMask(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2}); INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2});
@ -1482,9 +1448,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(expected, result2.eval()); assertEquals(expected, result2.eval());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMeshGrid(){ public void testMeshGrid(){
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -1540,9 +1505,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGather(){ public void testGather(){
List<INDArray> inArrs = new ArrayList<>(); List<INDArray> inArrs = new ArrayList<>();
List<Integer> axis = new ArrayList<>(); List<Integer> axis = new ArrayList<>();
@ -1611,9 +1575,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGatherSimple(Nd4jBackend backend) { public void testGatherSimple(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Nd4j.create(new float[]{1, 2, 3, 4}, new long[]{2, 2}); INDArray arr = Nd4j.create(new float[]{1, 2, 3, 4}, new long[]{2, 2});
@ -1623,9 +1586,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(expected, result.eval()); assertEquals(expected, result.eval());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGatherNdSingle(Nd4jBackend backend) { public void testGatherNdSingle(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(DataType.DOUBLE, 1, 24, 24)).reshape(2, 3, 4); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(DataType.DOUBLE, 1, 24, 24)).reshape(2, 3, 4);
@ -1644,9 +1606,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(expected, result.eval()); assertEquals(expected, result.eval());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStack2(Nd4jBackend backend) { public void testStack2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
@ -1657,9 +1618,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertArrayEquals(new long[]{3, 2, 2}, result.eval().shape()); assertArrayEquals(new long[]{3, 2, 2}, result.eval().shape());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testParallelStack(Nd4jBackend backend) { public void testParallelStack(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
@ -1671,9 +1631,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(Nd4j.concat(0, arr1, arr2).reshape(2, 3, 2), result.eval()); assertEquals(Nd4j.concat(0, arr1, arr2).reshape(2, 3, 2), result.eval());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUnStack2(Nd4jBackend backend) { public void testUnStack2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Nd4j.zeros(3, 2); INDArray arr1 = Nd4j.zeros(3, 2);
@ -1686,9 +1645,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(arr2, result[1].eval()); assertEquals(arr2, result[1].eval());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPermuteSimple(Nd4jBackend backend) { public void testPermuteSimple(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3)); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3));
@ -1699,9 +1657,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConcat2(Nd4jBackend backend) { public void testConcat2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
@ -1712,9 +1669,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertArrayEquals(new long[]{2, 4}, result.eval().shape()); assertArrayEquals(new long[]{2, 4}, result.eval().shape());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTile2(Nd4jBackend backend) { public void testTile2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1,4)); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1,4));
@ -1727,9 +1683,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSlice2d(Nd4jBackend backend) { public void testSlice2d(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
@ -1745,9 +1700,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSlice3d(Nd4jBackend backend) { public void testSlice3d(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
@ -1762,9 +1716,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(inArr.get(interval(1, 3), interval(2, 4), interval(3, 4)), m.get(subPart.name())); assertEquals(inArr.get(interval(1, 3), interval(2, 4), interval(3, 4)), m.get(subPart.name()));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStridedSlice2dBasic(Nd4jBackend backend) { public void testStridedSlice2dBasic(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
@ -1782,9 +1735,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStridedSliceBeginEndMask(Nd4jBackend backend) { public void testStridedSliceBeginEndMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
@ -1799,9 +1751,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(inArr.get(NDArrayIndex.interval(1, 3), NDArrayIndex.all()), slice2.getArr()); assertEquals(inArr.get(NDArrayIndex.interval(1, 3), NDArrayIndex.all()), slice2.getArr());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStridedSliceEllipsisMask(Nd4jBackend backend) { public void testStridedSliceEllipsisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1818,9 +1769,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(inArr.get(interval(1, 3), all(), all()), slice2.getArr()); assertEquals(inArr.get(interval(1, 3), all(), all()), slice2.getArr());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStridedSliceNewAxisMask(Nd4jBackend backend) { public void testStridedSliceNewAxisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1833,9 +1783,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(inArr, out.get(point(0), all(), all(), all())); assertEquals(inArr, out.get(point(0), all(), all(), all()));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStridedSliceNewAxisMask2(Nd4jBackend backend) { public void testStridedSliceNewAxisMask2(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1846,9 +1795,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertArrayEquals(new long[]{2, 2, 1, 3}, slice.getArr().shape()); assertArrayEquals(new long[]{2, 2, 1, 3}, slice.getArr().shape());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStridedSliceShrinkAxisMask(Nd4jBackend backend) { public void testStridedSliceShrinkAxisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
@ -1865,9 +1813,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(inArr.get(point(1), point(2), interval(1, 5)).reshape(4), slice3.getArr()); assertEquals(inArr.get(point(1), point(2), interval(1, 5)).reshape(4), slice3.getArr());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSizeAt_1(Nd4jBackend backend) { public void testSizeAt_1(Nd4jBackend backend) {
val array = Nd4j.create(10, 20, 30); val array = Nd4j.create(10, 20, 30);
val exp = Nd4j.scalar(DataType.LONG, 20); val exp = Nd4j.scalar(DataType.LONG, 20);
@ -1881,9 +1828,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(exp, output); assertEquals(exp, output);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEye(){ public void testEye(){
int[] rows = new int[]{3,3,3,3}; int[] rows = new int[]{3,3,3,3};
int[] cols = new int[]{3,2,2,2}; int[] cols = new int[]{3,2,2,2};
@ -1921,9 +1867,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSplit1(){ public void testSplit1(){
INDArray in = Nd4j.linspace(1,10,10).reshape(10); INDArray in = Nd4j.linspace(1,10,10).reshape(10);
INDArray axis = Nd4j.scalar(-1); INDArray axis = Nd4j.scalar(-1);
@ -1941,9 +1886,8 @@ public class ShapeOpValidation extends BaseOpValidation {
.build()).expectedOutput(0, exp1).expectedOutput(1,exp2))); .build()).expectedOutput(0, exp1).expectedOutput(1,exp2)));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSplit2(){ public void testSplit2(){
INDArray in = Nd4j.linspace(1,24,24).reshape(3,8); INDArray in = Nd4j.linspace(1,24,24).reshape(3,8);
INDArray axis = Nd4j.scalar(-1); INDArray axis = Nd4j.scalar(-1);
@ -1961,9 +1905,8 @@ public class ShapeOpValidation extends BaseOpValidation {
.build()).expectedOutput(0, exp1).expectedOutput(1,exp2))); .build()).expectedOutput(0, exp1).expectedOutput(1,exp2)));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDistancesExec(){ public void testDistancesExec(){
//https://github.com/deeplearning4j/deeplearning4j/issues/7001 //https://github.com/deeplearning4j/deeplearning4j/issues/7001
for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) { for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) {
@ -2018,9 +1961,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReductionShape(){ public void testReductionShape(){
INDArray shape = Nd4j.createFromArray(4,2); INDArray shape = Nd4j.createFromArray(4,2);
@ -2038,9 +1980,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertArrayEquals(exp, s); //Fails - actual shape [1] assertArrayEquals(exp, s); //Fails - actual shape [1]
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void gatherTest(){ public void gatherTest(){
INDArray in = Nd4j.createFromArray(new double[][]{ INDArray in = Nd4j.createFromArray(new double[][]{
{1,2,3,4,5}, {1,2,3,4,5},
@ -2059,9 +2000,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertArrayEquals(expShape, shape); //Fails: actual shape: [5] assertArrayEquals(expShape, shape); //Fails: actual shape: [5]
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSliceShape(){ public void testSliceShape(){
INDArray arr = Nd4j.arange(0, 25).reshape(1,5,5).castTo(DataType.INT); INDArray arr = Nd4j.arange(0, 25).reshape(1,5,5).castTo(DataType.INT);
@ -2082,9 +2022,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertArrayEquals(shapeExp, shape); assertArrayEquals(shapeExp, shape);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testWhereAllFalse(){ public void testWhereAllFalse(){
INDArray in = Nd4j.create(DataType.BOOL, 1917); INDArray in = Nd4j.create(DataType.BOOL, 1917);
DynamicCustomOp op = DynamicCustomOp.builder("Where") DynamicCustomOp op = DynamicCustomOp.builder("Where")
@ -2098,9 +2037,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertTrue(isEmpty); //Not empty, but should be assertTrue(isEmpty); //Not empty, but should be
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGatherScalar(){ public void testGatherScalar(){
INDArray in = Nd4j.linspace(100, 200, 100, DataType.FLOAT).reshape(100); INDArray in = Nd4j.linspace(100, 200, 100, DataType.FLOAT).reshape(100);
INDArray indices = Nd4j.scalar(0); INDArray indices = Nd4j.scalar(0);
@ -2124,9 +2062,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(exp, arr); assertEquals(exp, arr);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCastEmpty(){ public void testCastEmpty(){
INDArray emptyLong = Nd4j.empty(DataType.LONG); INDArray emptyLong = Nd4j.empty(DataType.LONG);
int dtype = 9; //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h int dtype = 9; //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h
@ -2142,9 +2079,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertTrue(isEmpty); assertTrue(isEmpty);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGatherEmpty(){ public void testGatherEmpty(){
/* /*
tf.reset_default_graph() tf.reset_default_graph()
@ -2176,9 +2112,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertTrue(isEmpty); assertTrue(isEmpty);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSplitEmpty(){ public void testSplitEmpty(){
/* /*
tf.reset_default_graph() tf.reset_default_graph()
@ -2215,9 +2150,8 @@ public class ShapeOpValidation extends BaseOpValidation {
Nd4j.exec(op); Nd4j.exec(op);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConcatEmpty(){ public void testConcatEmpty(){
/* /*
TF behaviour with concatenatioun of empty arrays: TF behaviour with concatenatioun of empty arrays:
@ -2266,9 +2200,8 @@ public class ShapeOpValidation extends BaseOpValidation {
Nd4j.exec(op); Nd4j.exec(op);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConcatEmpty2(){ public void testConcatEmpty2(){
INDArray empty10a = Nd4j.create(DataType.INT, 1, 0); INDArray empty10a = Nd4j.create(DataType.INT, 1, 0);
INDArray empty10b = Nd4j.create(DataType.INT, 1, 0); INDArray empty10b = Nd4j.create(DataType.INT, 1, 0);
@ -2300,9 +2233,8 @@ public class ShapeOpValidation extends BaseOpValidation {
Nd4j.exec(op); Nd4j.exec(op);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEmptyGather(){ public void testEmptyGather(){
/* /*
tf.reset_default_graph() tf.reset_default_graph()
@ -2334,9 +2266,8 @@ public class ShapeOpValidation extends BaseOpValidation {
op.addOutputArgument(out); op.addOutputArgument(out);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBroadcastDynamicShape1(){ public void testBroadcastDynamicShape1(){
//Test case: [2,1] and [4]: expect [2,4] //Test case: [2,1] and [4]: expect [2,4]
@ -2357,9 +2288,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(Nd4j.createFromArray(new int[]{2,4}), out); assertEquals(Nd4j.createFromArray(new int[]{2,4}), out);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBroadcastDynamicShape2(){ public void testBroadcastDynamicShape2(){
//Test case: [2,1,4] and [2,2,4]: expect [2,2,4] //Test case: [2,1,4] and [2,2,4]: expect [2,2,4]
@ -2381,9 +2311,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(Nd4j.createFromArray(new int[]{2,4,3}), out); assertEquals(Nd4j.createFromArray(new int[]{2,4,3}), out);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStridedSliceShrinkAxis(){ public void testStridedSliceShrinkAxis(){
INDArray in = Nd4j.create(DataType.DOUBLE, 3,2,2); INDArray in = Nd4j.create(DataType.DOUBLE, 3,2,2);
INDArray begin = Nd4j.createFromArray(2); INDArray begin = Nd4j.createFromArray(2);
@ -2408,9 +2337,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertArrayEquals(exp, shape); assertArrayEquals(exp, shape);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStridedSliceEmpty(){ public void testStridedSliceEmpty(){
INDArray in = Nd4j.createFromArray(10); //Integer, Length 1, rank 1, value 10 - Not used due to begin mask! INDArray in = Nd4j.createFromArray(10); //Integer, Length 1, rank 1, value 10 - Not used due to begin mask!
@ -2432,9 +2360,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertTrue(isEmpty); assertTrue(isEmpty);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStridedSliceEdgeCase(){ public void testStridedSliceEdgeCase(){
INDArray in = Nd4j.scalar(10).reshape(1); //Int [1] INDArray in = Nd4j.scalar(10).reshape(1); //Int [1]
INDArray begin = Nd4j.ones(DataType.INT, 1); INDArray begin = Nd4j.ones(DataType.INT, 1);
@ -2459,9 +2386,8 @@ public class ShapeOpValidation extends BaseOpValidation {
Nd4j.exec(op); //Execution is OK Nd4j.exec(op); //Execution is OK
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEmptySlice1(){ public void testEmptySlice1(){
INDArray in = Nd4j.createFromArray(38); INDArray in = Nd4j.createFromArray(38);
INDArray begin = Nd4j.createFromArray(1); INDArray begin = Nd4j.createFromArray(1);
@ -2480,9 +2406,8 @@ public class ShapeOpValidation extends BaseOpValidation {
Nd4j.exec(op); Nd4j.exec(op);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEmptySlice2(){ public void testEmptySlice2(){
INDArray in = Nd4j.createFromArray(38); INDArray in = Nd4j.createFromArray(38);
INDArray begin = Nd4j.createFromArray(0); INDArray begin = Nd4j.createFromArray(0);
@ -2501,9 +2426,8 @@ public class ShapeOpValidation extends BaseOpValidation {
Nd4j.exec(op); Nd4j.exec(op);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testFill(){ public void testFill(){
INDArray shape = Nd4j.createFromArray(0,4); INDArray shape = Nd4j.createFromArray(0,4);
@ -2522,9 +2446,8 @@ public class ShapeOpValidation extends BaseOpValidation {
Nd4j.exec(op); Nd4j.exec(op);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testFill2(){ public void testFill2(){
INDArray shape = Nd4j.createFromArray(0,4); INDArray shape = Nd4j.createFromArray(0,4);
@ -2541,9 +2464,8 @@ public class ShapeOpValidation extends BaseOpValidation {
Nd4j.exec(op); Nd4j.exec(op);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPermuteShapeDynamicAxis(){ public void testPermuteShapeDynamicAxis(){
DynamicCustomOp op = DynamicCustomOp.builder("permute") DynamicCustomOp op = DynamicCustomOp.builder("permute")
@ -2572,9 +2494,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertArrayEquals(new long[]{4, 5, 3}, l.get(0).getShape()); assertArrayEquals(new long[]{4, 5, 3}, l.get(0).getShape());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGather2(){ public void testGather2(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable input = sd.var("in", Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3)); SDVariable input = sd.var("in", Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3));
@ -2593,9 +2514,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPermute3(){ public void testPermute3(){
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2); INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
INDArray permute = Nd4j.createFromArray(1,0); INDArray permute = Nd4j.createFromArray(1,0);
@ -2613,9 +2533,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(exp, outArr); assertEquals(exp, outArr);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPermute4(){ public void testPermute4(){
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2); INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
INDArray permute = Nd4j.createFromArray(1,0); INDArray permute = Nd4j.createFromArray(1,0);
@ -2645,18 +2564,16 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInvertPermutation(){ public void testInvertPermutation(){
DynamicCustomOp op = DynamicCustomOp.builder("invert_permutation") DynamicCustomOp op = DynamicCustomOp.builder("invert_permutation")
.addInputs(Nd4j.createFromArray(1, 0)) .addInputs(Nd4j.createFromArray(1, 0))
.build(); .build();
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBroadcastInt1(Nd4jBackend backend) { public void testBroadcastInt1(Nd4jBackend backend) {
INDArray out = Nd4j.create(DataType.INT, 1); INDArray out = Nd4j.create(DataType.INT, 1);
@ -2669,9 +2586,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBroadcastInt2(){ public void testBroadcastInt2(){
INDArray out = Nd4j.create(DataType.INT, 2); INDArray out = Nd4j.create(DataType.INT, 2);
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape") DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
@ -2710,9 +2626,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMergeMaxIndex(Nd4jBackend backend) { public void testMergeMaxIndex(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -2729,9 +2644,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTriOp(Nd4jBackend backend) { public void testTriOp(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2743,9 +2657,8 @@ public class ShapeOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTriuOp(Nd4jBackend backend) { public void testTriuOp(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();

View File

@ -118,9 +118,8 @@ public class TransformOpValidation extends BaseOpValidation {
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScalarOps(Nd4jBackend backend) { public void testScalarOps(Nd4jBackend backend) {
int d0 = 2; int d0 = 2;
int d1 = 3; int d1 = 3;
@ -217,9 +216,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),failed.toString()); assertEquals(0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScalarMulCF(Nd4jBackend backend) { public void testScalarMulCF(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
@ -233,9 +231,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScalarMulCF2(Nd4jBackend backend) { public void testScalarMulCF2(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
@ -246,9 +243,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertEquals(outC, outF); assertEquals(outC, outF);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCross(Nd4jBackend backend) { public void testCross(Nd4jBackend backend) {
INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3}); INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3});
INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3}); INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3});
@ -276,9 +272,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err, err); assertNull(err, err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSpaceToDepth(Nd4jBackend backend) { public void testSpaceToDepth(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);
@ -306,9 +301,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDepthToSpace(Nd4jBackend backend) { public void testDepthToSpace(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);
@ -335,9 +329,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err, err); assertNull(err, err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBatchToSpace(Nd4jBackend backend) { public void testBatchToSpace(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863 //OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);
@ -374,9 +367,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err, err); assertNull(err, err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSpaceToBatch(Nd4jBackend backend) { public void testSpaceToBatch(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863 //OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
@ -414,9 +406,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err, err); assertNull(err, err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDynamicPartition(Nd4jBackend backend) { public void testDynamicPartition(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -456,9 +447,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDynamicPartition2(Nd4jBackend backend) { public void testDynamicPartition2(Nd4jBackend backend) {
INDArray data = Nd4j.createFromArray(2, 1, 2, 0); INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
@ -476,9 +466,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertEquals(exp2, out[2]); assertEquals(exp2, out[2]);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDynamicStitch(Nd4jBackend backend) { public void testDynamicStitch(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -515,9 +504,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDiag(Nd4jBackend backend) { public void testDiag(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -543,9 +531,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDiagPart(Nd4jBackend backend) { public void testDiagPart(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -564,9 +551,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEye(Nd4jBackend backend) { public void testEye(Nd4jBackend backend) {
int[] rows = new int[]{3, 3, 3, 3}; int[] rows = new int[]{3, 3, 3, 3};
int[] cols = new int[]{3, 2, 2, 2}; int[] cols = new int[]{3, 2, 2, 2};
@ -600,9 +586,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEyeShape(Nd4jBackend backend) { public void testEyeShape(Nd4jBackend backend) {
DynamicCustomOp dco = DynamicCustomOp.builder("eye") DynamicCustomOp dco = DynamicCustomOp.builder("eye")
.addIntegerArguments(3, 3) .addIntegerArguments(3, 3)
@ -614,9 +599,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertArrayEquals(new long[]{3, 3}, list.get(0).getShape()); assertArrayEquals(new long[]{3, 3}, list.get(0).getShape());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTransforms(Nd4jBackend backend) { public void testTransforms(Nd4jBackend backend) {
//Test transforms (non-pairwise) //Test transforms (non-pairwise)
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1104,9 +1088,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPairwiseTransforms(Nd4jBackend backend) { public void testPairwiseTransforms(Nd4jBackend backend) {
/* /*
add, sub, mul, div, rsub, rdiv add, sub, mul, div, rsub, rdiv
@ -1290,9 +1273,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIsX(Nd4jBackend backend) { public void testIsX(Nd4jBackend backend) {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -1347,9 +1329,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertEquals(0, failed.size(),failed.toString()); assertEquals(0, failed.size(),failed.toString());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReplaceWhereScalar(Nd4jBackend backend) { public void testReplaceWhereScalar(Nd4jBackend backend) {
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) { for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
@ -1371,9 +1352,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReplaceWhereArray(Nd4jBackend backend) { public void testReplaceWhereArray(Nd4jBackend backend) {
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) { for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
@ -1396,9 +1376,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLogGrad(Nd4jBackend backend) { public void testLogGrad(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
SDVariable input = sameDiff.var("x", Nd4j.linspace(1, 4, 4, DataType.DOUBLE)); SDVariable input = sameDiff.var("x", Nd4j.linspace(1, 4, 4, DataType.DOUBLE));
@ -1409,9 +1388,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSigmoidBackwards(Nd4jBackend backend) { public void testSigmoidBackwards(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
@ -1429,9 +1407,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
/* @Test /* @ParameterizedTest
@ParameterizedTest @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepth(Nd4jBackend backend) { public void testDepth(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
SDVariable x = sameDiff.one("one",new long[]{2,2}); SDVariable x = sameDiff.one("one",new long[]{2,2});
@ -1440,9 +1417,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertEquals(1,sigmoid.depth()); assertEquals(1,sigmoid.depth());
}*/ }*/
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRank0EdgeCase(Nd4jBackend backend) { public void testRank0EdgeCase(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))); SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4})));
@ -1455,9 +1431,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertEquals(4, d1, 0); assertEquals(4, d1, 0);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAtan2BroadcastShape(Nd4jBackend backend) { public void testAtan2BroadcastShape(Nd4jBackend backend) {
INDArray arr1 = Nd4j.create(new long[]{3, 1, 4}); INDArray arr1 = Nd4j.create(new long[]{3, 1, 4});
INDArray arr2 = Nd4j.create(new long[]{1, 2, 4}); INDArray arr2 = Nd4j.create(new long[]{1, 2, 4});
@ -1472,9 +1447,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertArrayEquals(new long[]{3, 2, 4}, outShapes.get(0).getShape(),Arrays.toString(outShapes.get(0).getShape())); assertArrayEquals(new long[]{3, 2, 4}, outShapes.get(0).getShape(),Arrays.toString(outShapes.get(0).getShape()));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBooleanAnd(Nd4jBackend backend) { public void testBooleanAnd(Nd4jBackend backend) {
Nd4j.setDataType(DataType.FLOAT); Nd4j.setDataType(DataType.FLOAT);
INDArray arr1 = Nd4j.create(new long[]{3, 4}); INDArray arr1 = Nd4j.create(new long[]{3, 4});
@ -1488,9 +1462,8 @@ public class TransformOpValidation extends BaseOpValidation {
Nd4j.getExecutioner().exec(op); Nd4j.getExecutioner().exec(op);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScatterOpsScalar(Nd4jBackend backend) { public void testScatterOpsScalar(Nd4jBackend backend) {
for (String s : new String[]{"add", "sub", "mul", "div"}) { for (String s : new String[]{"add", "sub", "mul", "div"}) {
INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3); INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
@ -1535,9 +1508,8 @@ public class TransformOpValidation extends BaseOpValidation {
@Disabled("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540") @Disabled("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540")
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPad(Nd4jBackend backend) { public void testPad(Nd4jBackend backend) {
INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0); INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG); INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG);
@ -1564,9 +1536,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMirrorPad(Nd4jBackend backend) { public void testMirrorPad(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT); INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
@ -1599,9 +1570,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err2); assertNull(err2);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMirrorPad2(Nd4jBackend backend) { public void testMirrorPad2(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT); INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
@ -1627,9 +1597,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMirrorPadSymmetric(Nd4jBackend backend) { public void testMirrorPadSymmetric(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4); INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT); INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT);
@ -1656,9 +1625,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUnique(Nd4jBackend backend) { public void testUnique(Nd4jBackend backend) {
INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4}); INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4});
@ -1680,9 +1648,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTopK(Nd4jBackend backend) { public void testTopK(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //Can't assume sorted here OpValidationSuite.ignoreFailing(); //Can't assume sorted here
INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8}); INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8});
@ -1711,9 +1678,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTopK1(Nd4jBackend backend) { public void testTopK1(Nd4jBackend backend) {
INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0); INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
INDArray k = Nd4j.scalar(1); INDArray k = Nd4j.scalar(1);
@ -1734,9 +1700,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertEquals(expIdx, outIdx); assertEquals(expIdx, outIdx);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInTopK(Nd4jBackend backend) { public void testInTopK(Nd4jBackend backend) {
for (int k = 4; k >= 1; k--) { for (int k = 4; k >= 1; k--) {
log.info("Testing: k=" + k); log.info("Testing: k=" + k);
@ -1777,9 +1742,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testZeta(Nd4jBackend backend) { public void testZeta(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182 OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182
INDArray x = Nd4j.rand(3, 4).addi(1.0); INDArray x = Nd4j.rand(3, 4).addi(1.0);
@ -1796,9 +1760,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNotEquals(Nd4j.create(out.shape()), out); assertNotEquals(Nd4j.create(out.shape()), out);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMaxEmptyScalar(Nd4jBackend backend) { public void testMaxEmptyScalar(Nd4jBackend backend) {
INDArray empty = Nd4j.empty(DataType.FLOAT); INDArray empty = Nd4j.empty(DataType.FLOAT);
INDArray scalar = Nd4j.scalar(1.0f); INDArray scalar = Nd4j.scalar(1.0f);
@ -1815,9 +1778,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertTrue(isEmpty); assertTrue(isEmpty);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBroadcastEmpty(Nd4jBackend backend) { public void testBroadcastEmpty(Nd4jBackend backend) {
// Nd4j.getExecutioner().enableVerboseMode(true); // Nd4j.getExecutioner().enableVerboseMode(true);
// Nd4j.getExecutioner().enableDebugMode(true); // Nd4j.getExecutioner().enableDebugMode(true);
@ -1907,9 +1869,8 @@ public class TransformOpValidation extends BaseOpValidation {
return false; return false;
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStandardize(Nd4jBackend backend) { public void testStandardize(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4}); final INDArray random = Nd4j.rand(new int[]{10, 4});
@ -1930,9 +1891,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err, err); assertNull(err, err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStandardizeOP(Nd4jBackend backend) { public void testStandardizeOP(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4}); final INDArray random = Nd4j.rand(new int[]{10, 4});
@ -1947,9 +1907,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertEquals(res, output); assertEquals(res, output);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStandardizeNoDeviation(Nd4jBackend backend) { public void testStandardizeNoDeviation(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4}); final INDArray random = Nd4j.rand(new int[]{10, 4});
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
@ -1975,9 +1934,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err, err); assertNull(err, err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMatMulTensor(Nd4jBackend backend) { public void testMatMulTensor(Nd4jBackend backend) {
final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5}); final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5});
final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6}); final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6});
@ -1997,9 +1955,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err, err); assertNull(err, err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMatMulTensorTranspose(Nd4jBackend backend) { public void testMatMulTensorTranspose(Nd4jBackend backend) {
for (boolean transposeA : new boolean[]{false, true}) { for (boolean transposeA : new boolean[]{false, true}) {
for (boolean transposeB : new boolean[]{false, true}) { for (boolean transposeB : new boolean[]{false, true}) {
@ -2092,9 +2049,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSoftmaxCF(Nd4jBackend backend) { public void testSoftmaxCF(Nd4jBackend backend) {
INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5); INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5);
@ -2115,9 +2071,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertEquals(outCC, outFF); assertEquals(outCC, outFF);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLogSumExp(Nd4jBackend backend) { public void testLogSumExp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4); INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4);
@ -2132,9 +2087,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertEquals(log, out); assertEquals(log, out);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLogSumExp2(Nd4jBackend backend) { public void testLogSumExp2(Nd4jBackend backend) {
for (int dim = 0; dim <= 2; dim++) { for (int dim = 0; dim <= 2; dim++) {
@ -2155,9 +2109,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCRELU(Nd4jBackend backend) { public void testCRELU(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -2176,9 +2129,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testClipByAvgNorm(Nd4jBackend backend) { public void testClipByAvgNorm(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -2199,9 +2151,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEmbeddingLookup(Nd4jBackend backend) { public void testEmbeddingLookup(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -2214,9 +2165,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testImageResize(Nd4jBackend backend) { public void testImageResize(Nd4jBackend backend) {
//TODO: Methods failed ResizeLanczos5, ResizeMitchelcubic, ResizeArea //TODO: Methods failed ResizeLanczos5, ResizeMitchelcubic, ResizeArea
@ -2258,9 +2208,8 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMaximumBp(Nd4jBackend backend) { public void testMaximumBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -2277,9 +2226,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMergeAddBp(Nd4jBackend backend) { public void testMergeAddBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -2296,9 +2244,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMergeMaxBp(Nd4jBackend backend) { public void testMergeMaxBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -2316,9 +2263,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMergeAvgBp(Nd4jBackend backend) { public void testMergeAvgBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -2335,9 +2281,8 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReverseBp(Nd4jBackend backend) { public void testReverseBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -2351,9 +2296,8 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err); assertNull(err);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUpsampling3dBp(Nd4jBackend backend) { public void testUpsampling3dBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -45,9 +45,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDeConv2D(Nd4jBackend backend){ public void testDeConv2D(Nd4jBackend backend){
DeConv2DConfig.builder().kH(2).kW(4).build(); DeConv2DConfig.builder().kH(2).kW(4).build();
@ -108,9 +107,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConv2D(Nd4jBackend backend){ public void testConv2D(Nd4jBackend backend){
Conv2DConfig.builder().kH(2).kW(4).build(); Conv2DConfig.builder().kH(2).kW(4).build();
@ -171,9 +169,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPooling2D(Nd4jBackend backend){ public void testPooling2D(Nd4jBackend backend){
Pooling2DConfig.builder().kH(2).kW(4).build(); Pooling2DConfig.builder().kH(2).kW(4).build();
@ -234,9 +231,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDeConv3D(Nd4jBackend backend){ public void testDeConv3D(Nd4jBackend backend){
DeConv3DConfig.builder().kH(2).kW(4).kD(3).build(); DeConv3DConfig.builder().kH(2).kW(4).kD(3).build();
@ -325,9 +321,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConv3D(Nd4jBackend backend){ public void testConv3D(Nd4jBackend backend){
Conv3DConfig.builder().kH(2).kW(4).kD(3).build(); Conv3DConfig.builder().kH(2).kW(4).kD(3).build();
@ -418,9 +413,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPooling3D(Nd4jBackend backend){ public void testPooling3D(Nd4jBackend backend){
Pooling3DConfig.builder().kH(2).kW(4).kD(3).build(); Pooling3DConfig.builder().kH(2).kW(4).kD(3).build();
@ -509,9 +503,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConv1D(){ public void testConv1D(){
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build(); Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();

View File

@ -50,9 +50,8 @@ public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEye(Nd4jBackend backend){ public void testEye(Nd4jBackend backend){
//OpValidationSuite.ignoreFailing(); //OpValidationSuite.ignoreFailing();
INDArray arr = Nd4j.create(new double[]{1, 0, 0, 0, 1, 0}, new int[]{2, 3}); INDArray arr = Nd4j.create(new double[]{1, 0, 0, 0, 1, 0}, new int[]{2, 3});
@ -68,9 +67,8 @@ public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
assertEquals(expOut, result.eval()); assertEquals(expOut, result.eval());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEyeShape(Nd4jBackend backend){ public void testEyeShape(Nd4jBackend backend){
val dco = DynamicCustomOp.builder("eye") val dco = DynamicCustomOp.builder("eye")
.addIntegerArguments(3,3) .addIntegerArguments(3,3)
@ -82,9 +80,8 @@ public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
assertArrayEquals(new long[]{3,3}, list.get(0).getShape()); assertArrayEquals(new long[]{3,3}, list.get(0).getShape());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testExecutionDifferentShapesTransform(Nd4jBackend backend){ public void testExecutionDifferentShapesTransform(Nd4jBackend backend){
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -105,9 +102,8 @@ public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
assertEquals(exp, out2); assertEquals(exp, out2);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDropout(Nd4jBackend backend) { public void testDropout(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -120,9 +116,8 @@ public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
assertArrayEquals(new long[]{2, 2}, res.getShape()); assertArrayEquals(new long[]{2, 2}, res.getShape());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testExecutionDifferentShapesDynamicCustom(Nd4jBackend backend){ public void testExecutionDifferentShapesDynamicCustom(Nd4jBackend backend){
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();

View File

@ -67,9 +67,7 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import static junit.framework.TestCase.assertNotNull; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j @Slf4j
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends { public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
@ -82,9 +80,8 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4); INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
@ -139,9 +136,8 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
assertEquals(sd.getLossVariables().size(), fg.lossVariablesLength()); assertEquals(sd.getLossVariables().size(), fg.lossVariablesLength());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
for( int i = 0; i < 10; i++ ) { for( int i = 0; i < 10; i++ ) {
for(boolean execFirst : new boolean[]{false, true}) { for(boolean execFirst : new boolean[]{false, true}) {
@ -270,9 +266,8 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
//Ensure 2 things: //Ensure 2 things:
@ -356,9 +351,8 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void pooling3DSerialization(Nd4jBackend backend){ public void pooling3DSerialization(Nd4jBackend backend){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -378,9 +372,8 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
deserialized.getVariableOutputOp("pool").getClass()); deserialized.getVariableOutputOp("pool").getClass());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void pooling3DSerialization2(Nd4jBackend backend){ public void pooling3DSerialization2(Nd4jBackend backend){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();

View File

@ -52,9 +52,8 @@ public class GraphTransformUtilTests extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBasic(Nd4jBackend backend){ public void testBasic(Nd4jBackend backend){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -93,9 +92,8 @@ public class GraphTransformUtilTests extends BaseNd4jTestWithBackends {
assertEquals(0, sg2.getChildNodes().size()); assertEquals(0, sg2.getChildNodes().size());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSubgraphReplace1(Nd4jBackend backend){ public void testSubgraphReplace1(Nd4jBackend backend){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();

View File

@ -42,9 +42,8 @@ public class MemoryMgrTest extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testArrayReuseTooLarge(Nd4jBackend backend) throws Exception { public void testArrayReuseTooLarge(Nd4jBackend backend) throws Exception {
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr(); ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
@ -116,9 +115,8 @@ public class MemoryMgrTest extends BaseNd4jTestWithBackends {
assertEquals(0, mmgr.getLruCacheValues().size()); assertEquals(0, mmgr.getLruCacheValues().size());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testManyArrays(Nd4jBackend backend){ public void testManyArrays(Nd4jBackend backend){
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr(); ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();

View File

@ -45,9 +45,8 @@ public class NameScopeTests extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testVariableNameScopesBasic(Nd4jBackend backend) { public void testVariableNameScopesBasic(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -73,9 +72,8 @@ public class NameScopeTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOpFieldsAndNames(Nd4jBackend backend) { public void testOpFieldsAndNames(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -153,9 +151,8 @@ public class NameScopeTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNoNesting(Nd4jBackend backend) { public void testNoNesting(Nd4jBackend backend) {
SameDiff SD = SameDiff.create(); SameDiff SD = SameDiff.create();
@ -172,9 +169,8 @@ public class NameScopeTests extends BaseNd4jTestWithBackends {
assertTrue(SD.variableMap().containsKey("test/argmax"),"Var with name test/argmax exists"); assertTrue(SD.variableMap().containsKey("test/argmax"),"Var with name test/argmax exists");
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNoTesting2(Nd4jBackend backend) { public void testNoTesting2(Nd4jBackend backend) {
SameDiff SD = SameDiff.create(); SameDiff SD = SameDiff.create();

View File

@ -49,9 +49,8 @@ public class SameDiffMultiThreadTests extends BaseND4JTest {
return Long.MAX_VALUE; return Long.MAX_VALUE;
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSimple(Nd4jBackend backend) throws Exception { public void testSimple(Nd4jBackend backend) throws Exception {
int nThreads = 4; int nThreads = 4;

View File

@ -36,9 +36,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class SameDiffOutputTest extends BaseNd4jTestWithBackends { public class SameDiffOutputTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void outputTest(Nd4jBackend backend){ public void outputTest(Nd4jBackend backend){
DataSet data = new DataSet(Nd4j.zeros(10, 10), Nd4j.zeros(10, 10)); DataSet data = new DataSet(Nd4j.zeros(10, 10), Nd4j.zeros(10, 10));
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();

View File

@ -33,8 +33,6 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.assertNull;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends { public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
@ -45,9 +43,8 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSpecifiedLoss1(Nd4jBackend backend) { public void testSpecifiedLoss1(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable ph1 = sd.var("ph", DataType.FLOAT, 3, 4); SDVariable ph1 = sd.var("ph", DataType.FLOAT, 3, 4);
@ -68,9 +65,8 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
assertNotNull(ph1.gradient()); assertNotNull(ph1.gradient());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSpecifiedLoss2(Nd4jBackend backend) { public void testSpecifiedLoss2(Nd4jBackend backend) {
for( int i = 0; i < 2; i++) { for( int i = 0; i < 2; i++) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -111,7 +107,7 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
for(String s : new String[]{"w", "b", badd.name(), add.name(), "l1", "l2"}){ for(String s : new String[]{"w", "b", badd.name(), add.name(), "l1", "l2"}){
SDVariable gradVar = sd.getVariable(s).gradient(); SDVariable gradVar = sd.getVariable(s).gradient();
assertNotNull(s, gradVar); assertNotNull(gradVar,s);
} }
//Unused: //Unused:
assertFalse(shape.hasGradient()); assertFalse(shape.hasGradient());
@ -123,9 +119,8 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTrainingDifferentLosses(Nd4jBackend backend) { public void testTrainingDifferentLosses(Nd4jBackend backend) {
//Net with 2 losses: train on the first one, then change losses //Net with 2 losses: train on the first one, then change losses
//Also check that if modifying via add/setLossVariables the training config changes //Also check that if modifying via add/setLossVariables the training config changes
@ -154,20 +149,20 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
sd.setLossVariables("loss1"); sd.setLossVariables("loss1");
sd.createGradFunction(); sd.createGradFunction();
for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){ for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){
assertNotNull(v.name(), v.gradient()); assertNotNull(v.gradient(),v.name());
} }
for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){ for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){
assertNull(v.name(), v.gradient()); assertNull(v.gradient(),v.name());
} }
//Now, set to other loss function //Now, set to other loss function
sd.setLossVariables("loss2"); sd.setLossVariables("loss2");
sd.createGradFunction(); sd.createGradFunction();
for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){ for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){
assertNull(v.name(), v.gradient()); assertNull(v.gradient(),v.name());
} }
for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){ for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){
assertNotNull(v.name(), v.gradient()); assertNotNull(v.gradient(),v.name());
} }
//Train the first side of the graph. The other side should remain unmodified! //Train the first side of the graph. The other side should remain unmodified!

View File

@ -60,9 +60,8 @@ import org.nd4j.weightinit.impl.XavierInitScheme;
public class SameDiffTrainingTest extends BaseNd4jTestWithBackends { public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void irisTrainingSanityCheck(Nd4jBackend backend) { public void irisTrainingSanityCheck(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -134,9 +133,8 @@ public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void irisTrainingEvalTest(Nd4jBackend backend) { public void irisTrainingEvalTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -186,9 +184,8 @@ public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void irisTrainingValidationTest(Nd4jBackend backend) { public void irisTrainingValidationTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -243,9 +240,8 @@ public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTrainingMixedDtypes(){ public void testTrainingMixedDtypes(){
for (String u : new String[]{"adam", "nesterov", "adamax", "amsgrad"}) { for (String u : new String[]{"adam", "nesterov", "adamax", "amsgrad"}) {
@ -307,9 +303,8 @@ public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void simpleClassification(Nd4jBackend backend) { public void simpleClassification(Nd4jBackend backend) {
double learning_rate = 0.001; double learning_rate = 0.001;
int seed = 7; int seed = 7;
@ -356,9 +351,8 @@ public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
History history = sd.fit(new SingletonMultiDataSetIterator(mds), 1); History history = sd.fit(new SingletonMultiDataSetIterator(mds), 1);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTrainingEvalVarNotReqForLoss(){ public void testTrainingEvalVarNotReqForLoss(){
//If a variable is not required for the loss - normally it won't be calculated //If a variable is not required for the loss - normally it won't be calculated
//But we want to make sure it IS calculated here - so we can perform evaluation on it //But we want to make sure it IS calculated here - so we can perform evaluation on it

View File

@ -20,7 +20,7 @@
package org.nd4j.autodiff.samediff.listeners; package org.nd4j.autodiff.samediff.listeners;
import org.junit.Assert;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -47,8 +47,9 @@ import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class CheckpointListenerTest extends BaseNd4jTestWithBackends { public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
@ -94,9 +95,8 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
@ -130,9 +130,8 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
assertTrue(found1 && found2 && found3); assertTrue(found1 && found2 && found3);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
@ -166,14 +165,13 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
assertEquals(5, files.length); //4 checkpoints and 1 text file (metadata) assertEquals(5, files.length); //4 checkpoints and 1 text file (metadata)
for( int i = 0; i < found.length; i++) { for( int i = 0; i < found.length; i++) {
assertTrue(names.get(i), found[i]); assertTrue(found[i], names.get(i));
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();
@ -213,13 +211,12 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
} }
for( int i = 0; i < found.length; i++) { for( int i = 0; i < found.length; i++) {
assertTrue(names.get(i), found[i]); assertTrue(found[i], names.get(i));
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();
@ -258,8 +255,8 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
} }
assertEquals(5, cpNums.size(),cpNums.toString()); assertEquals(5, cpNums.size(),cpNums.toString());
Assert.assertTrue(cpNums.toString(), cpNums.containsAll(Arrays.asList(2, 5, 7, 8, 9))); assertTrue(cpNums.containsAll(Arrays.asList(2, 5, 7, 8, 9)), cpNums.toString());
Assert.assertTrue(epochNums.toString(), epochNums.containsAll(Arrays.asList(5, 11, 15, 17, 19))); assertTrue(epochNums.containsAll(Arrays.asList(5, 11, 15, 17, 19)), epochNums.toString());
assertEquals(5, l.availableCheckpoints().size()); assertEquals(5, l.availableCheckpoints().size());
} }

View File

@ -38,9 +38,8 @@ import org.nd4j.linalg.learning.config.Adam;
public class ExecDebuggingListenerTest extends BaseNd4jTestWithBackends { public class ExecDebuggingListenerTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testExecDebugListener(Nd4jBackend backend) { public void testExecDebugListener(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();

View File

@ -71,9 +71,8 @@ public class ListenerTest extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void irisHistoryTest(Nd4jBackend backend) { public void irisHistoryTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -136,9 +135,8 @@ public class ListenerTest extends BaseNd4jTestWithBackends {
assertTrue(acc >= 0.75,"Accuracy < 75%, was " + acc); assertTrue(acc >= 0.75,"Accuracy < 75%, was " + acc);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testListenerCalls(){ public void testListenerCalls(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
@ -275,9 +273,8 @@ public class ListenerTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCustomListener(Nd4jBackend backend) { public void testCustomListener(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4); SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);

View File

@ -57,9 +57,8 @@ public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -108,25 +107,22 @@ public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
} }
/* /*
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLoadTfProfile(){ public void testLoadTfProfile(){
File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json"); File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json");
ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLoadTfProfileDir(){ public void testLoadTfProfileDir(){
File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles"); File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles");
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLoadTfProfileDir2(){ public void testLoadTfProfileDir2(){
File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0"); File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0");
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);

View File

@ -79,9 +79,8 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
Nd4j.getRandom().setSeed(123); Nd4j.getRandom().setSeed(123);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException { public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4); SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
@ -185,9 +184,8 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{ public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{
File dir = testDir.toFile(); File dir = testDir.toFile();
File f = new File(dir, "temp.bin"); File f = new File(dir, "temp.bin");

View File

@ -63,9 +63,8 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -101,9 +100,8 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
assertArrayEquals(new long[]{150, 3}, out.shape()); assertArrayEquals(new long[]{150, 3}, out.shape());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -194,9 +192,8 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
SameDiff sd1 = getSimpleNet(); SameDiff sd1 = getSimpleNet();

View File

@ -40,9 +40,8 @@ public class CustomEvaluationTest extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void customEvalTest(Nd4jBackend backend){ public void customEvalTest(Nd4jBackend backend){
CustomEvaluation accuracyEval = new CustomEvaluation<>( CustomEvaluation accuracyEval = new CustomEvaluation<>(
(labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)), (labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)),

View File

@ -45,9 +45,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEmptyEvaluation (Nd4jBackend backend) { public void testEmptyEvaluation (Nd4jBackend backend) {
Evaluation e = new Evaluation(); Evaluation e = new Evaluation();
System.out.println(e.stats()); System.out.println(e.stats());
@ -62,9 +61,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEmptyRegressionEvaluation (Nd4jBackend backend) { public void testEmptyRegressionEvaluation (Nd4jBackend backend) {
RegressionEvaluation re = new RegressionEvaluation(); RegressionEvaluation re = new RegressionEvaluation();
re.stats(); re.stats();
@ -78,9 +76,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEmptyEvaluationBinary(Nd4jBackend backend) { public void testEmptyEvaluationBinary(Nd4jBackend backend) {
EvaluationBinary eb = new EvaluationBinary(); EvaluationBinary eb = new EvaluationBinary();
eb.stats(); eb.stats();
@ -95,9 +92,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEmptyROC(Nd4jBackend backend) { public void testEmptyROC(Nd4jBackend backend) {
ROC roc = new ROC(); ROC roc = new ROC();
roc.stats(); roc.stats();
@ -112,9 +108,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEmptyROCBinary(Nd4jBackend backend) { public void testEmptyROCBinary(Nd4jBackend backend) {
ROCBinary rb = new ROCBinary(); ROCBinary rb = new ROCBinary();
rb.stats(); rb.stats();
@ -129,9 +124,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEmptyROCMultiClass(Nd4jBackend backend) { public void testEmptyROCMultiClass(Nd4jBackend backend) {
ROCMultiClass r = new ROCMultiClass(); ROCMultiClass r = new ROCMultiClass();
r.stats(); r.stats();
@ -146,9 +140,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEmptyEvaluationCalibration(Nd4jBackend backend) { public void testEmptyEvaluationCalibration(Nd4jBackend backend) {
EvaluationCalibration ec = new EvaluationCalibration(); EvaluationCalibration ec = new EvaluationCalibration();
ec.stats(); ec.stats();

View File

@ -46,9 +46,8 @@ public class EvalCustomThreshold extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationCustomBinaryThreshold(Nd4jBackend backend) { public void testEvaluationCustomBinaryThreshold(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -114,9 +113,8 @@ public class EvalCustomThreshold extends BaseNd4jTestWithBackends {
assertEquals(ex2.getConfusionMatrix(), e025v2.getConfusionMatrix()); assertEquals(ex2.getConfusionMatrix(), e025v2.getConfusionMatrix());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationCostArray(Nd4jBackend backend) { public void testEvaluationCostArray(Nd4jBackend backend) {
@ -164,9 +162,8 @@ public class EvalCustomThreshold extends BaseNd4jTestWithBackends {
assertEquals(1.0, e2.accuracy(), 1e-6); assertEquals(1.0, e2.accuracy(), 1e-6);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationBinaryCustomThreshold(Nd4jBackend backend) { public void testEvaluationBinaryCustomThreshold(Nd4jBackend backend) {
//Sanity check: same results for 0.5 threshold vs. default (no threshold) //Sanity check: same results for 0.5 threshold vs. default (no threshold)

View File

@ -39,9 +39,7 @@ import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import static junit.framework.TestCase.assertNull; import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class EvalJsonTest extends BaseNd4jTestWithBackends { public class EvalJsonTest extends BaseNd4jTestWithBackends {
@ -52,9 +50,8 @@ public class EvalJsonTest extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSerdeEmpty(Nd4jBackend backend) { public void testSerdeEmpty(Nd4jBackend backend) {
boolean print = false; boolean print = false;
@ -74,9 +71,8 @@ public class EvalJsonTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSerde(Nd4jBackend backend) { public void testSerde(Nd4jBackend backend) {
boolean print = false; boolean print = false;
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -124,9 +120,8 @@ public class EvalJsonTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSerdeExactRoc(Nd4jBackend backend) { public void testSerdeExactRoc(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
boolean print = false; boolean print = false;
@ -204,9 +199,8 @@ public class EvalJsonTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testJsonYamlCurves(Nd4jBackend backend) { public void testJsonYamlCurves(Nd4jBackend backend) {
ROC roc = new ROC(0); ROC roc = new ROC(0);
@ -258,9 +252,8 @@ public class EvalJsonTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testJsonWithCustomThreshold(Nd4jBackend backend) { public void testJsonWithCustomThreshold(Nd4jBackend backend) {
//Evaluation - binary threshold //Evaluation - binary threshold

View File

@ -50,9 +50,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEval(Nd4jBackend backend) { public void testEval(Nd4jBackend backend) {
int classNum = 5; int classNum = 5;
Evaluation eval = new Evaluation (classNum); Evaluation eval = new Evaluation (classNum);
@ -91,9 +90,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
assertEquals(0.5, eval.accuracy(), 0); assertEquals(0.5, eval.accuracy(), 0);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEval2(Nd4jBackend backend) { public void testEval2(Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
@ -152,9 +150,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStringListLabels(Nd4jBackend backend) { public void testStringListLabels(Nd4jBackend backend) {
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
@ -171,9 +168,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStringHashLabels(Nd4jBackend backend) { public void testStringHashLabels(Nd4jBackend backend) {
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
@ -190,9 +186,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvalMasking(Nd4jBackend backend) { public void testEvalMasking(Nd4jBackend backend) {
int miniBatch = 5; int miniBatch = 5;
int nOut = 3; int nOut = 3;
@ -259,9 +254,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testFalsePerfectRecall(Nd4jBackend backend) { public void testFalsePerfectRecall(Nd4jBackend backend) {
int testSize = 100; int testSize = 100;
int numClasses = 5; int numClasses = 5;
@ -294,9 +288,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
assertNotEquals(1.0, eval.recall()); assertNotEquals(1.0, eval.recall());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationMerging(Nd4jBackend backend) { public void testEvaluationMerging(Nd4jBackend backend) {
int nRows = 20; int nRows = 20;
@ -370,9 +363,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSingleClassBinaryClassification(Nd4jBackend backend) { public void testSingleClassBinaryClassification(Nd4jBackend backend) {
Evaluation eval = new Evaluation(1); Evaluation eval = new Evaluation(1);
@ -401,9 +393,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvalInvalid(Nd4jBackend backend) { public void testEvalInvalid(Nd4jBackend backend) {
Evaluation e = new Evaluation(5); Evaluation e = new Evaluation(5);
e.eval(0, 1); e.eval(0, 1);
@ -416,9 +407,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
assertFalse(e.stats().contains("\uFFFD")); assertFalse(e.stats().contains("\uFFFD"));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvalMethods(Nd4jBackend backend) { public void testEvalMethods(Nd4jBackend backend) {
//Check eval(int,int) vs. eval(INDArray,INDArray) //Check eval(int,int) vs. eval(INDArray,INDArray)
@ -461,9 +451,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTopNAccuracy(Nd4jBackend backend) { public void testTopNAccuracy(Nd4jBackend backend) {
Evaluation e = new Evaluation(null, 3); Evaluation e = new Evaluation(null, 3);
@ -524,9 +513,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTopNAccuracyMerging(Nd4jBackend backend) { public void testTopNAccuracyMerging(Nd4jBackend backend) {
Evaluation e1 = new Evaluation(null, 3); Evaluation e1 = new Evaluation(null, 3);
@ -574,9 +562,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
assertEquals(6.0 / 8, e1.topNAccuracy(), 1e-6); assertEquals(6.0 / 8, e1.topNAccuracy(), 1e-6);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBinaryCase(Nd4jBackend backend) { public void testBinaryCase(Nd4jBackend backend) {
INDArray ones10 = Nd4j.ones(10, 1); INDArray ones10 = Nd4j.ones(10, 1);
INDArray ones4 = Nd4j.ones(4, 1); INDArray ones4 = Nd4j.ones(4, 1);
@ -605,9 +592,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
assertEquals(2, (int) e.truePositives().get(0)); assertEquals(2, (int) e.truePositives().get(0));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testF1FBeta_MicroMacroAveraging(Nd4jBackend backend) { public void testF1FBeta_MicroMacroAveraging(Nd4jBackend backend) {
//Confusion matrix: rows = actual, columns = predicted //Confusion matrix: rows = actual, columns = predicted
//[3, 1, 0] //[3, 1, 0]
@ -748,9 +734,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConfusionMatrixStats(Nd4jBackend backend) { public void testConfusionMatrixStats(Nd4jBackend backend) {
Evaluation e = new Evaluation(); Evaluation e = new Evaluation();
@ -771,9 +756,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvalBinaryMetrics(){ public void testEvalBinaryMetrics(){
Evaluation ePosClass1_nOut2 = new Evaluation(2, 1); Evaluation ePosClass1_nOut2 = new Evaluation(2, 1);
@ -894,9 +878,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConfusionMatrixString(){ public void testConfusionMatrixString(){
Evaluation e = new Evaluation(Arrays.asList("a","b","c")); Evaluation e = new Evaluation(Arrays.asList("a","b","c"));
@ -946,9 +929,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
e.stats(false, true); e.stats(false, true);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationNaNs(){ public void testEvaluationNaNs(){
Evaluation e = new Evaluation(); Evaluation e = new Evaluation();
@ -963,9 +945,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSegmentation(){ public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1059,9 +1040,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLabelReset(){ public void testLabelReset(){
Map<Integer,String> m = new HashMap<>(); Map<Integer,String> m = new HashMap<>();
@ -1094,9 +1074,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
assertEquals(s1, s2); assertEquals(s1, s2);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvalStatsBinaryCase(){ public void testEvalStatsBinaryCase(){
//Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case //Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case

View File

@ -48,9 +48,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationBinary(Nd4jBackend backend) { public void testEvaluationBinary(Nd4jBackend backend) {
//Compare EvaluationBinary to Evaluation class //Compare EvaluationBinary to Evaluation class
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
@ -136,9 +135,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationBinaryMerging(Nd4jBackend backend) { public void testEvaluationBinaryMerging(Nd4jBackend backend) {
int nOut = 4; int nOut = 4;
int[] shape1 = {30, nOut}; int[] shape1 = {30, nOut};
@ -165,9 +163,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
assertEquals(eb.stats(), eb1.stats()); assertEquals(eb.stats(), eb1.stats());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationBinaryPerOutputMasking(Nd4jBackend backend) { public void testEvaluationBinaryPerOutputMasking(Nd4jBackend backend) {
//Provide a mask array: "ignore" the masked steps //Provide a mask array: "ignore" the masked steps
@ -210,9 +207,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
assertEquals(1, eb.falseNegatives(2)); assertEquals(1, eb.falseNegatives(2));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTimeSeriesEval(Nd4jBackend backend) { public void testTimeSeriesEval(Nd4jBackend backend) {
int[] shape = {2, 4, 3}; int[] shape = {2, 4, 3};
@ -236,9 +232,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
assertEquals(eb2.stats(), eb1.stats()); assertEquals(eb2.stats(), eb1.stats());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationBinaryWithROC(Nd4jBackend backend) { public void testEvaluationBinaryWithROC(Nd4jBackend backend) {
//Simple test for nested ROCBinary in EvaluationBinary //Simple test for nested ROCBinary in EvaluationBinary
@ -255,9 +250,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationBinary3d(Nd4jBackend backend) { public void testEvaluationBinary3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -291,9 +285,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationBinary4d(Nd4jBackend backend) { public void testEvaluationBinary4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -327,9 +320,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationBinary3dMasking(Nd4jBackend backend) { public void testEvaluationBinary3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -390,9 +382,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationBinary4dMasking(Nd4jBackend backend) { public void testEvaluationBinary4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -49,9 +49,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReliabilityDiagram (Nd4jBackend backend) { public void testReliabilityDiagram (Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
@ -143,9 +142,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLabelAndPredictionCounts (Nd4jBackend backend) { public void testLabelAndPredictionCounts (Nd4jBackend backend) {
int minibatch = 50; int minibatch = 50;
@ -173,9 +171,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
assertArrayEquals(expPredictionCount, ec.getPredictionCountsEachClass()); assertArrayEquals(expPredictionCount, ec.getPredictionCountsEachClass());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testResidualPlots (Nd4jBackend backend) { public void testResidualPlots (Nd4jBackend backend) {
int minibatch = 50; int minibatch = 50;
@ -276,9 +273,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSegmentation(){ public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -372,9 +368,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationCalibration3d (Nd4jBackend backend) { public void testEvaluationCalibration3d (Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -406,9 +401,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
assertEquals(e2d.stats(), e3d.stats()); assertEquals(e2d.stats(), e3d.stats());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvaluationCalibration3dMasking (Nd4jBackend backend) { public void testEvaluationCalibration3dMasking (Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);

View File

@ -46,9 +46,8 @@ public class NewInstanceTest extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNewInstances(Nd4jBackend backend) { public void testNewInstances(Nd4jBackend backend) {
boolean print = true; boolean print = true;
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -48,9 +48,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinary(Nd4jBackend backend) { public void testROCBinary(Nd4jBackend backend) {
//Compare ROCBinary to ROC class //Compare ROCBinary to ROC class
@ -145,9 +144,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocBinaryMerging(Nd4jBackend backend) { public void testRocBinaryMerging(Nd4jBackend backend) {
for (int nSteps : new int[]{30, 0}) { //0 == exact for (int nSteps : new int[]{30, 0}) { //0 == exact
int nOut = 4; int nOut = 4;
@ -177,9 +175,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinaryPerOutputMasking(Nd4jBackend backend) { public void testROCBinaryPerOutputMasking(Nd4jBackend backend) {
for (int nSteps : new int[]{30, 0}) { //0 == exact for (int nSteps : new int[]{30, 0}) { //0 == exact
@ -219,9 +216,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinary3d(Nd4jBackend backend) { public void testROCBinary3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -255,9 +251,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinary4d(Nd4jBackend backend) { public void testROCBinary4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -291,9 +286,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinary3dMasking(Nd4jBackend backend) { public void testROCBinary3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -354,9 +348,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinary4dMasking(Nd4jBackend backend) { public void testROCBinary4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -82,9 +82,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
expFPR.put(10 / 10.0, 0.0 / totalNegatives); expFPR.put(10 / 10.0, 0.0 / totalNegatives);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocBasic(Nd4jBackend backend) { public void testRocBasic(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax) //2 outputs here - probability distribution over classes (softmax)
INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
@ -127,9 +126,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
assertEquals(1.0, auc, 1e-6); assertEquals(1.0, auc, 1e-6);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocBasicSingleClass(Nd4jBackend backend) { public void testRocBasicSingleClass(Nd4jBackend backend) {
//1 output here - single probability value (sigmoid) //1 output here - single probability value (sigmoid)
@ -167,9 +165,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRoc(Nd4jBackend backend) { public void testRoc(Nd4jBackend backend) {
//Previous tests allowed for a perfect classifier with right threshold... //Previous tests allowed for a perfect classifier with right threshold...
@ -254,9 +251,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) { public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
//Same as first test... //Same as first test...
@ -303,9 +299,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocTimeSeriesMasking(Nd4jBackend backend) { public void testRocTimeSeriesMasking(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax) //2 outputs here - probability distribution over classes (softmax)
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
@ -355,9 +350,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) { public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -387,9 +381,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCompare2Vs3Classes(Nd4jBackend backend) { public void testCompare2Vs3Classes(Nd4jBackend backend) {
//ROC multi-class: 2 vs. 3 classes should be the same, if we add two of the classes together... //ROC multi-class: 2 vs. 3 classes should be the same, if we add two of the classes together...
@ -438,9 +431,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCMerging(Nd4jBackend backend) { public void testROCMerging(Nd4jBackend backend) {
int nArrays = 10; int nArrays = 10;
int minibatch = 64; int minibatch = 64;
@ -485,9 +477,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCMerging2(Nd4jBackend backend) { public void testROCMerging2(Nd4jBackend backend) {
int nArrays = 10; int nArrays = 10;
int minibatch = 64; int minibatch = 64;
@ -532,9 +523,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCMultiMerging(Nd4jBackend backend) { public void testROCMultiMerging(Nd4jBackend backend) {
int nArrays = 10; int nArrays = 10;
@ -582,9 +572,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAUCPrecisionRecall(Nd4jBackend backend) { public void testAUCPrecisionRecall(Nd4jBackend backend) {
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob //Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
//at threshold 0 to 0.24999: tp=2, fp=1, fn=0, tn=0 prec=2/(2+1)=0.666, recall=2/2=1.0 //at threshold 0 to 0.24999: tp=2, fp=1, fn=0, tn=0 prec=2/(2+1)=0.666, recall=2/2=1.0
@ -631,9 +620,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocAucExact(Nd4jBackend backend) { public void testRocAucExact(Nd4jBackend backend) {
//Check the implementation vs. Scikitlearn //Check the implementation vs. Scikitlearn
@ -796,9 +784,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) { public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
//Set reallocation block size to say 20, but then evaluate a 100-length array //Set reallocation block size to say 20, but then evaluate a 100-length array
@ -810,9 +797,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) { public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
double[] threshold = new double[101]; double[] threshold = new double[101];
double[] precision = threshold; double[] precision = threshold;
@ -848,9 +834,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) { public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
//Sanity check: values calculated from the confusion matrix should match the PR curve values //Sanity check: values calculated from the confusion matrix should match the PR curve values
@ -889,9 +874,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocMerge(){ public void testRocMerge(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -935,9 +919,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
assertEquals(auprc, auprcAct, 1e-6); assertEquals(auprc, auprcAct, 1e-6);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocMultiMerge(){ public void testRocMultiMerge(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -986,9 +969,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocBinaryMerge(){ public void testRocBinaryMerge(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1033,9 +1015,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSegmentationBinary(){ public void testSegmentationBinary(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1125,9 +1106,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSegmentation(){ public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -63,9 +63,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPerfectPredictions(Nd4jBackend backend) { public void testPerfectPredictions(Nd4jBackend backend) {
int nCols = 5; int nCols = 5;
@ -92,9 +91,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testKnownValues(Nd4jBackend backend) { public void testKnownValues(Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
@ -150,9 +148,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRegressionEvaluationMerging(Nd4jBackend backend) { public void testRegressionEvaluationMerging(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -193,9 +190,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRegressionEvalPerOutputMasking(Nd4jBackend backend) { public void testRegressionEvalPerOutputMasking(Nd4jBackend backend) {
INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}}); INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}});
@ -222,9 +218,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRegressionEvalTimeSeriesSplit(){ public void testRegressionEvalTimeSeriesSplit(){
INDArray out1 = Nd4j.rand(new int[]{3, 5, 20}); INDArray out1 = Nd4j.rand(new int[]{3, 5, 20});
@ -246,9 +241,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
assertEquals(e1, e2); assertEquals(e1, e2);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRegressionEval3d(Nd4jBackend backend) { public void testRegressionEval3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -280,9 +274,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRegressionEval4d(Nd4jBackend backend) { public void testRegressionEval4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -314,9 +307,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRegressionEval3dMasking(Nd4jBackend backend) { public void testRegressionEval3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -375,9 +367,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRegressionEval4dMasking(Nd4jBackend backend) { public void testRegressionEval4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -44,9 +44,8 @@ public class TestLegacyJsonLoading extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvalLegacyFormat(Nd4jBackend backend) throws Exception { public void testEvalLegacyFormat(Nd4jBackend backend) throws Exception {
File f = new ClassPathResource("regression_testing/eval_100b/evaluation.json").getFile(); File f = new ClassPathResource("regression_testing/eval_100b/evaluation.json").getFile();

View File

@ -60,9 +60,8 @@ public class AveragingTests extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSingleDeviceAveraging1(Nd4jBackend backend) { public void testSingleDeviceAveraging1(Nd4jBackend backend) {
INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0); INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0);
INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0); INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0);
@ -109,9 +108,8 @@ public class AveragingTests extends BaseNd4jTestWithBackends {
assertEquals(arrayMean, array16); assertEquals(arrayMean, array16);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSingleDeviceAveraging2(Nd4jBackend backend) { public void testSingleDeviceAveraging2(Nd4jBackend backend) {
INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH); INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH);
List<INDArray> arrays = new ArrayList<>(); List<INDArray> arrays = new ArrayList<>();
@ -128,9 +126,8 @@ public class AveragingTests extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAccumulation1(Nd4jBackend backend) { public void testAccumulation1(Nd4jBackend backend) {
INDArray array1 = Nd4j.create(100).assign(1.0); INDArray array1 = Nd4j.create(100).assign(1.0);
INDArray array2 = Nd4j.create(100).assign(2.0); INDArray array2 = Nd4j.create(100).assign(2.0);
@ -143,9 +140,8 @@ public class AveragingTests extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAccumulation2(Nd4jBackend backend) { public void testAccumulation2(Nd4jBackend backend) {
INDArray array1 = Nd4j.create(100).assign(1.0); INDArray array1 = Nd4j.create(100).assign(1.0);
INDArray array2 = Nd4j.create(100).assign(2.0); INDArray array2 = Nd4j.create(100).assign(2.0);
@ -160,9 +156,8 @@ public class AveragingTests extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAccumulation3(Nd4jBackend backend) { public void testAccumulation3(Nd4jBackend backend) {
// we want to ensure that cuda backend is able to launch this op on cpu // we want to ensure that cuda backend is able to launch this op on cpu
Nd4j.getAffinityManager().allowCrossDeviceAccess(false); Nd4j.getAffinityManager().allowCrossDeviceAccess(false);

View File

@ -39,9 +39,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j @Slf4j
public class DataTypeTest extends BaseNd4jTestWithBackends { public class DataTypeTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDataTypes(Nd4jBackend backend) throws Exception { public void testDataTypes(Nd4jBackend backend) throws Exception {
for (val type : DataType.values()) { for (val type : DataType.values()) {
if (DataType.UTF8.equals(type) || DataType.UNKNOWN.equals(type) || DataType.COMPRESSED.equals(type)) if (DataType.UTF8.equals(type) || DataType.UNKNOWN.equals(type) || DataType.COMPRESSED.equals(type))

View File

@ -41,9 +41,8 @@ public class InputValidationTests extends BaseNd4jTestWithBackends {
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
///////////////////// Broadcast Tests /////////////////////// ///////////////////// Broadcast Tests ///////////////////////
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInvalidColVectorOp1(Nd4jBackend backend) { public void testInvalidColVectorOp1(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10); INDArray first = Nd4j.create(10, 10);
INDArray col = Nd4j.create(5, 1); INDArray col = Nd4j.create(5, 1);
@ -55,9 +54,8 @@ public class InputValidationTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInvalidColVectorOp2(Nd4jBackend backend) { public void testInvalidColVectorOp2(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10); INDArray first = Nd4j.create(10, 10);
INDArray col = Nd4j.create(5, 1); INDArray col = Nd4j.create(5, 1);
@ -69,9 +67,8 @@ public class InputValidationTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInvalidRowVectorOp1(Nd4jBackend backend) { public void testInvalidRowVectorOp1(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10); INDArray first = Nd4j.create(10, 10);
INDArray row = Nd4j.create(1, 5); INDArray row = Nd4j.create(1, 5);
@ -83,9 +80,8 @@ public class InputValidationTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInvalidRowVectorOp2(Nd4jBackend backend) { public void testInvalidRowVectorOp2(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10); INDArray first = Nd4j.create(10, 10);
INDArray row = Nd4j.create(1, 5); INDArray row = Nd4j.create(1, 5);

View File

@ -51,9 +51,8 @@ import static org.junit.jupiter.api.Assertions.*;
public class LoneTest extends BaseNd4jTestWithBackends { public class LoneTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSoftmaxStability(Nd4jBackend backend) { public void testSoftmaxStability(Nd4jBackend backend) {
INDArray input = Nd4j.create(new double[]{-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose(); INDArray input = Nd4j.create(new double[]{-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose();
// System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo())); // System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo()));
@ -67,9 +66,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testFlattenedView(Nd4jBackend backend) { public void testFlattenedView(Nd4jBackend backend) {
int rows = 8; int rows = 8;
int cols = 8; int cols = 8;
@ -105,9 +103,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
assertEquals(fAssertion, Nd4j.toFlattened('f', first)); assertEquals(fAssertion, Nd4j.toFlattened('f', first));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIndexingColVec(Nd4jBackend backend) { public void testIndexingColVec(Nd4jBackend backend) {
int elements = 5; int elements = 5;
INDArray rowVector = Nd4j.linspace(1, elements, elements).reshape(1, elements); INDArray rowVector = Nd4j.linspace(1, elements, elements).reshape(1, elements);
@ -126,9 +123,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void concatScalarVectorIssue(Nd4jBackend backend) { public void concatScalarVectorIssue(Nd4jBackend backend) {
//A bug was found when the first array that concat sees is a scalar and the rest vectors + scalars //A bug was found when the first array that concat sees is a scalar and the rest vectors + scalars
INDArray arr1 = Nd4j.create(1, 1); INDArray arr1 = Nd4j.create(1, 1);
@ -138,9 +134,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
assertTrue(arr4.sumNumber().floatValue() <= Nd4j.EPS_THRESHOLD); assertTrue(arr4.sumNumber().floatValue() <= Nd4j.EPS_THRESHOLD);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void reshapeTensorMmul(Nd4jBackend backend) { public void reshapeTensorMmul(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 2, 12).reshape(2, 3, 2); INDArray a = Nd4j.linspace(1, 2, 12).reshape(2, 3, 2);
INDArray b = Nd4j.linspace(3, 4, 4).reshape(2, 2); INDArray b = Nd4j.linspace(3, 4, 4).reshape(2, 2);
@ -152,9 +147,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
INDArray c = Nd4j.tensorMmul(b, a, axes); INDArray c = Nd4j.tensorMmul(b, a, axes);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void maskWhenMerge(Nd4jBackend backend) { public void maskWhenMerge(Nd4jBackend backend) {
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5)); DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3)); DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
@ -169,9 +163,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRelu(Nd4jBackend backend) { public void testRelu(Nd4jBackend backend) {
INDArray aA = Nd4j.linspace(-3, 4, 8).reshape(2, 4); INDArray aA = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4); INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
@ -197,9 +190,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
assertEquals(max - 1, currentArgMax); assertEquals(max - 1, currentArgMax);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRPF(Nd4jBackend backend) { public void testRPF(Nd4jBackend backend) {
val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3); val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3);
@ -212,9 +204,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
log.info("TAD:\n{}", tad); log.info("TAD:\n{}", tad);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConcat3D_Vstack_C(Nd4jBackend backend) { public void testConcat3D_Vstack_C(Nd4jBackend backend) {
val shape = new long[]{1, 1000, 20}; val shape = new long[]{1, 1000, 20};
@ -244,9 +235,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetRow1(Nd4jBackend backend) { public void testGetRow1(Nd4jBackend backend) {
INDArray array = Nd4j.create(10000, 10000); INDArray array = Nd4j.create(10000, 10000);
@ -285,9 +275,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void checkSliceofSlice(Nd4jBackend backend) { public void checkSliceofSlice(Nd4jBackend backend) {
/* /*
Issue 1: Slice of slice with c order and f order views are not equal Issue 1: Slice of slice with c order and f order views are not equal
@ -327,9 +316,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void checkWithReshape(Nd4jBackend backend) { public void checkWithReshape(Nd4jBackend backend) {
INDArray arr = Nd4j.create(1, 3); INDArray arr = Nd4j.create(1, 3);
INDArray reshaped = arr.reshape('f', 3, 1); INDArray reshaped = arr.reshape('f', 3, 1);

View File

@ -38,9 +38,8 @@ public class MmulBug extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void simpleTest(Nd4jBackend backend) { public void simpleTest(Nd4jBackend backend) {
INDArray m1 = Nd4j.create(new double[][] {{1.0}, {2.0}, {3.0}, {4.0}}); INDArray m1 = Nd4j.create(new double[][] {{1.0}, {2.0}, {3.0}, {4.0}});

View File

@ -63,9 +63,8 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class NDArrayTestsFortran extends BaseNd4jTestWithBackends { public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScalarOps(Nd4jBackend backend) { public void testScalarOps(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3}); INDArray n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
assertEquals(27d, n.length(), 1e-1); assertEquals(27d, n.length(), 1e-1);
@ -83,9 +82,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testColumnMmul(Nd4jBackend backend) { public void testColumnMmul(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 10, 18, DataType.FLOAT).data(); DataBuffer data = Nd4j.linspace(1, 10, 18, DataType.FLOAT).data();
INDArray x2 = Nd4j.create(data, new long[] {2, 3, 3}); INDArray x2 = Nd4j.create(data, new long[] {2, 3, 3});
@ -116,9 +114,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRowVectorGemm(Nd4jBackend backend) { public void testRowVectorGemm(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1).castTo(DataType.DOUBLE); INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1).castTo(DataType.DOUBLE);
INDArray other = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4).castTo(DataType.DOUBLE); INDArray other = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4).castTo(DataType.DOUBLE);
@ -129,18 +126,16 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRepmat(Nd4jBackend backend) { public void testRepmat(Nd4jBackend backend) {
INDArray rowVector = Nd4j.create(1, 4); INDArray rowVector = Nd4j.create(1, 4);
INDArray repmat = rowVector.repmat(4, 4); INDArray repmat = rowVector.repmat(4, 4);
assertTrue(Arrays.equals(new long[] {4, 16}, repmat.shape())); assertTrue(Arrays.equals(new long[] {4, 16}, repmat.shape()));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReadWrite() throws Exception { public void testReadWrite() throws Exception {
INDArray write = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray write = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
ByteArrayOutputStream bos = new ByteArrayOutputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream();
@ -155,9 +150,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReadWriteDouble() throws Exception { public void testReadWriteDouble() throws Exception {
INDArray write = Nd4j.linspace(1, 4, 4, DataType.FLOAT); INDArray write = Nd4j.linspace(1, 4, 4, DataType.FLOAT);
ByteArrayOutputStream bos = new ByteArrayOutputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream();
@ -173,9 +167,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMultiThreading() throws Exception { public void testMultiThreading() throws Exception {
ExecutorService ex = ExecutorServiceProvider.getExecutorService(); ExecutorService ex = ExecutorServiceProvider.getExecutorService();
@ -195,9 +188,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBroadcastingGenerated(Nd4jBackend backend) { public void testBroadcastingGenerated(Nd4jBackend backend) {
int[][] broadcastShape = NDArrayCreationUtil.getRandomBroadCastShape(7, 6, 10); int[][] broadcastShape = NDArrayCreationUtil.getRandomBroadCastShape(7, 6, 10);
List<List<Pair<INDArray, String>>> broadCastList = new ArrayList<>(broadcastShape.length); List<List<Pair<INDArray, String>>> broadCastList = new ArrayList<>(broadcastShape.length);
@ -222,9 +214,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBroadCasting(Nd4jBackend backend) { public void testBroadCasting(Nd4jBackend backend) {
INDArray first = Nd4j.arange(0, 3).reshape(3, 1).castTo(DataType.DOUBLE); INDArray first = Nd4j.arange(0, 3).reshape(3, 1).castTo(DataType.DOUBLE);
INDArray ret = first.broadcast(3, 4); INDArray ret = first.broadcast(3, 4);
@ -237,18 +228,16 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOneTensor(Nd4jBackend backend) { public void testOneTensor(Nd4jBackend backend) {
INDArray arr = Nd4j.ones(1, 1, 1, 1, 1, 1, 1); INDArray arr = Nd4j.ones(1, 1, 1, 1, 1, 1, 1);
INDArray matrixToBroadcast = Nd4j.ones(1, 1); INDArray matrixToBroadcast = Nd4j.ones(1, 1);
assertEquals(matrixToBroadcast.broadcast(arr.shape()), arr); assertEquals(matrixToBroadcast.broadcast(arr.shape()), arr);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSortWithIndicesDescending(Nd4jBackend backend) { public void testSortWithIndicesDescending(Nd4jBackend backend) {
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
//indices,data //indices,data
@ -259,9 +248,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
assertEquals(shouldIndex, sorted[0],getFailureMessage()); assertEquals(shouldIndex, sorted[0],getFailureMessage());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSortDeadlock(Nd4jBackend backend) { public void testSortDeadlock(Nd4jBackend backend) {
val toSort = Nd4j.linspace(DataType.DOUBLE, 1, 32*768, 1).reshape(32, 768); val toSort = Nd4j.linspace(DataType.DOUBLE, 1, 32*768, 1).reshape(32, 768);
@ -269,9 +257,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSortWithIndices(Nd4jBackend backend) { public void testSortWithIndices(Nd4jBackend backend) {
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
//indices,data //indices,data
@ -282,18 +269,16 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
assertEquals(shouldIndex, sorted[0],getFailureMessage()); assertEquals(shouldIndex, sorted[0],getFailureMessage());
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNd4jSortScalar(Nd4jBackend backend) { public void testNd4jSortScalar(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(1, -1); INDArray linspace = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(1, -1);
INDArray sorted = Nd4j.sort(linspace, 1, false); INDArray sorted = Nd4j.sort(linspace, 1, false);
// System.out.println(sorted); // System.out.println(sorted);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSwapAxesFortranOrder(Nd4jBackend backend) { public void testSwapAxesFortranOrder(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}).castTo(DataType.DOUBLE); INDArray n = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}).castTo(DataType.DOUBLE);
for (int i = 0; i < n.slices(); i++) { for (int i = 0; i < n.slices(); i++) {
@ -312,9 +297,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDimShuffle(Nd4jBackend backend) { public void testDimShuffle(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray twoOneTwo = n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false}); INDArray twoOneTwo = n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false});
@ -325,9 +309,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetVsGetScalar(Nd4jBackend backend) { public void testGetVsGetScalar(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
float element = a.getFloat(0, 1); float element = a.getFloat(0, 1);
@ -340,9 +323,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDivide(Nd4jBackend backend) { public void testDivide(Nd4jBackend backend) {
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2}); INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
INDArray div = two.div(two); INDArray div = two.div(two);
@ -356,9 +338,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSigmoid(Nd4jBackend backend) { public void testSigmoid(Nd4jBackend backend) {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f}); INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
@ -367,9 +348,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNeg(Nd4jBackend backend) { public void testNeg(Nd4jBackend backend) {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4}); INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
@ -379,9 +359,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCosineSim(Nd4jBackend backend) { public void testCosineSim(Nd4jBackend backend) {
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
@ -396,9 +375,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testExp(Nd4jBackend backend) { public void testExp(Nd4jBackend backend) {
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f}); INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f});
@ -408,9 +386,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScalar(Nd4jBackend backend) { public void testScalar(Nd4jBackend backend) {
INDArray a = Nd4j.scalar(1.0f); INDArray a = Nd4j.scalar(1.0f);
assertEquals(true, a.isScalar()); assertEquals(true, a.isScalar());
@ -422,9 +399,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testWrap(Nd4jBackend backend) { public void testWrap(Nd4jBackend backend) {
int[] shape = {2, 4}; int[] shape = {2, 4};
INDArray d = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(shape[0], shape[1]); INDArray d = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(shape[0], shape[1]);
@ -449,9 +425,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
assertEquals(row22.columns(), 2); assertEquals(row22.columns(), 2);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetRowFortran(Nd4jBackend backend) { public void testGetRowFortran(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.FLOAT).data(), new long[] {2, 2}); INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.FLOAT).data(), new long[] {2, 2});
INDArray column = Nd4j.create(new float[] {1, 3}); INDArray column = Nd4j.create(new float[] {1, 3});
@ -464,9 +439,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetColumnFortran(Nd4jBackend backend) { public void testGetColumnFortran(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2}); INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2});
INDArray column = Nd4j.create(new double[] {1, 2}); INDArray column = Nd4j.create(new double[] {1, 2});
@ -480,9 +454,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetColumns(Nd4jBackend backend) { public void testGetColumns(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3).castTo(DataType.DOUBLE); INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3).castTo(DataType.DOUBLE);
// log.info("Original: {}", matrix); // log.info("Original: {}", matrix);
@ -496,9 +469,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testVectorInit(Nd4jBackend backend) { public void testVectorInit(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(); DataBuffer data = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data();
INDArray arr = Nd4j.create(data, new long[] {1, 4}); INDArray arr = Nd4j.create(data, new long[] {1, 4});
@ -511,9 +483,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAssignOffset(Nd4jBackend backend) { public void testAssignOffset(Nd4jBackend backend) {
INDArray arr = Nd4j.ones(5, 5); INDArray arr = Nd4j.ones(5, 5);
INDArray row = arr.slice(1); INDArray row = arr.slice(1);
@ -521,9 +492,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
assertEquals(Nd4j.ones(5), row); assertEquals(Nd4j.ones(5), row);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testColumns(Nd4jBackend backend) { public void testColumns(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE); INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE);
INDArray column = Nd4j.create(new double[] {1, 2, 3}); INDArray column = Nd4j.create(new double[] {1, 2, 3});
@ -561,9 +531,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPutRow(Nd4jBackend backend) { public void testPutRow(Nd4jBackend backend) {
INDArray d = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray d = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray n = d.dup(); INDArray n = d.dup();
@ -622,9 +591,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInplaceTranspose(Nd4jBackend backend) { public void testInplaceTranspose(Nd4jBackend backend) {
INDArray test = Nd4j.rand(3, 4); INDArray test = Nd4j.rand(3, 4);
INDArray orig = test.dup(); INDArray orig = test.dup();
@ -639,9 +607,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMmulF(Nd4jBackend backend) { public void testMmulF(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data(); DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data();
@ -659,9 +626,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRowsColumns(Nd4jBackend backend) { public void testRowsColumns(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).data(); DataBuffer data = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).data();
INDArray rows = Nd4j.create(data, new long[] {2, 3}); INDArray rows = Nd4j.create(data, new long[] {2, 3});
@ -677,9 +643,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTranspose(Nd4jBackend backend) { public void testTranspose(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.ones(100).castTo(DataType.DOUBLE).data(), new long[] {5, 5, 4}); INDArray n = Nd4j.create(Nd4j.ones(100).castTo(DataType.DOUBLE).data(), new long[] {5, 5, 4});
INDArray transpose = n.transpose(); INDArray transpose = n.transpose();
@ -707,9 +672,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAddMatrix(Nd4jBackend backend) { public void testAddMatrix(Nd4jBackend backend) {
INDArray five = Nd4j.ones(5); INDArray five = Nd4j.ones(5);
five.addi(five.dup()); five.addi(five.dup());
@ -720,9 +684,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMMul(Nd4jBackend backend) { public void testMMul(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}}); INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
@ -733,9 +696,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPutSlice(Nd4jBackend backend) { public void testPutSlice(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3); INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3);
INDArray newSlice = Nd4j.create(DataType.DOUBLE, 3, 3); INDArray newSlice = Nd4j.create(DataType.DOUBLE, 3, 3);
@ -746,9 +708,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRowVectorMultipleIndices(Nd4jBackend backend) { public void testRowVectorMultipleIndices(Nd4jBackend backend) {
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4); INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
linear.putScalar(new long[] {0, 1}, 1); linear.putScalar(new long[] {0, 1}, 1);
@ -757,9 +718,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDim1(Nd4jBackend backend) { public void testDim1(Nd4jBackend backend) {
INDArray sum = Nd4j.linspace(1, 2, 2, DataType.DOUBLE).reshape(2, 1); INDArray sum = Nd4j.linspace(1, 2, 2, DataType.DOUBLE).reshape(2, 1);
INDArray same = sum.dup(); INDArray same = sum.dup();
@ -767,9 +727,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEps(Nd4jBackend backend) { public void testEps(Nd4jBackend backend) {
val ones = Nd4j.ones(5); val ones = Nd4j.ones(5);
val res = Nd4j.createUninitialized(DataType.BOOL, 5); val res = Nd4j.createUninitialized(DataType.BOOL, 5);
@ -777,9 +736,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLogDouble(Nd4jBackend backend) { public void testLogDouble(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).castTo(DataType.DOUBLE); INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).castTo(DataType.DOUBLE);
INDArray log = Transforms.log(linspace); INDArray log = Transforms.log(linspace);
@ -787,36 +745,32 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
assertEquals(assertion, log); assertEquals(assertion, log);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testVectorSum(Nd4jBackend backend) { public void testVectorSum(Nd4jBackend backend) {
INDArray lin = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray lin = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1); assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testVectorSum2(Nd4jBackend backend) { public void testVectorSum2(Nd4jBackend backend) {
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1); assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testVectorSum3(Nd4jBackend backend) { public void testVectorSum3(Nd4jBackend backend) {
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray lin2 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray lin2 = Nd4j.create(new double[] {1, 2, 3, 4});
assertEquals(lin, lin2); assertEquals(lin, lin2);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSmallSum(Nd4jBackend backend) { public void testSmallSum(Nd4jBackend backend) {
INDArray base = Nd4j.create(new double[] {5.843333333333335, 3.0540000000000007}); INDArray base = Nd4j.create(new double[] {5.843333333333335, 3.0540000000000007});
base.addi(1e-12); base.addi(1e-12);
@ -827,9 +781,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPermute(Nd4jBackend backend) { public void testPermute(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4}); INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4});
INDArray transpose = n.transpose(); INDArray transpose = n.transpose();
@ -858,9 +811,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAppendBias(Nd4jBackend backend) { public void testAppendBias(Nd4jBackend backend) {
INDArray rand = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(1, -1).transpose(); INDArray rand = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(1, -1).transpose();
INDArray test = Nd4j.appendBias(rand); INDArray test = Nd4j.appendBias(rand);
@ -868,9 +820,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
assertEquals(assertion, test); assertEquals(assertion, test);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRand(Nd4jBackend backend) { public void testRand(Nd4jBackend backend) {
INDArray rand = Nd4j.randn(5, 5); INDArray rand = Nd4j.randn(5, 5);
Nd4j.getDistributions().createUniform(0.4, 4).sample(5); Nd4j.getDistributions().createUniform(0.4, 4).sample(5);
@ -882,9 +833,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIdentity(Nd4jBackend backend) { public void testIdentity(Nd4jBackend backend) {
INDArray eye = Nd4j.eye(5); INDArray eye = Nd4j.eye(5);
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape())); assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
@ -895,9 +845,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testColumnVectorOpsFortran(Nd4jBackend backend) { public void testColumnVectorOpsFortran(Nd4jBackend backend) {
INDArray twoByTwo = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray twoByTwo = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray toAdd = Nd4j.create(new float[] {1, 2}, new long[] {2, 1}); INDArray toAdd = Nd4j.create(new float[] {1, 2}, new long[] {2, 1});
@ -908,9 +857,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRSubi(Nd4jBackend backend) { public void testRSubi(Nd4jBackend backend) {
INDArray n2 = Nd4j.ones(2); INDArray n2 = Nd4j.ones(2);
INDArray n2Assertion = Nd4j.zeros(2); INDArray n2Assertion = Nd4j.zeros(2);
@ -920,9 +868,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAssign(Nd4jBackend backend) { public void testAssign(Nd4jBackend backend) {
INDArray vector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray vector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
vector.assign(1); vector.assign(1);
@ -939,9 +886,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
assertEquals(tensor, ones); assertEquals(tensor, ones);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAddScalar(Nd4jBackend backend) { public void testAddScalar(Nd4jBackend backend) {
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0); INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
INDArray rdiv = div.add(1); INDArray rdiv = div.add(1);
@ -949,9 +895,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
assertEquals(answer, rdiv); assertEquals(answer, rdiv);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRdivScalar(Nd4jBackend backend) { public void testRdivScalar(Nd4jBackend backend) {
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0); INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
INDArray rdiv = div.rdiv(1); INDArray rdiv = div.rdiv(1);
@ -959,9 +904,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
assertEquals(rdiv, answer); assertEquals(rdiv, answer);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRDivi(Nd4jBackend backend) { public void testRDivi(Nd4jBackend backend) {
INDArray n2 = Nd4j.valueArrayOf(new long[] {1, 2}, 4.0); INDArray n2 = Nd4j.valueArrayOf(new long[] {1, 2}, 4.0);
INDArray n2Assertion = Nd4j.valueArrayOf(new long[] {1, 2}, 0.5); INDArray n2Assertion = Nd4j.valueArrayOf(new long[] {1, 2}, 0.5);
@ -971,9 +915,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNumVectorsAlongDimension(Nd4jBackend backend) { public void testNumVectorsAlongDimension(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2);
assertEquals(12, arr.vectorsAlongDimension(2)); assertEquals(12, arr.vectorsAlongDimension(2));
@ -981,9 +924,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBroadCast(Nd4jBackend backend) { public void testBroadCast(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
INDArray broadCasted = n.broadcast(5, 4); INDArray broadCasted = n.broadcast(5, 4);
@ -1005,9 +947,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
assertTrue(Arrays.equals(new long[] {1, 2, 36, 36}, broadCasted3.shape())); assertTrue(Arrays.equals(new long[] {1, 2, 36, 36}, broadCasted3.shape()));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMatrix(Nd4jBackend backend) { public void testMatrix(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray brr = Nd4j.create(new double[] {5, 6}, new long[] {2}); INDArray brr = Nd4j.create(new double[] {5, 6}, new long[] {2});
@ -1017,9 +958,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPutRowGetRowOrdering(Nd4jBackend backend) { public void testPutRowGetRowOrdering(Nd4jBackend backend) {
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray put = Nd4j.create(new double[] {5, 6}); INDArray put = Nd4j.create(new double[] {5, 6});
@ -1041,9 +981,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSumWithRow1(Nd4jBackend backend) { public void testSumWithRow1(Nd4jBackend backend) {
//Works: //Works:
INDArray array2d = Nd4j.ones(1, 10); INDArray array2d = Nd4j.ones(1, 10);
@ -1074,9 +1013,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
array5d.sum(4); //java.lang.IllegalArgumentException: Illegal index 10000 derived from 9 with offset of 1000 and stride of 1000 array5d.sum(4); //java.lang.IllegalArgumentException: Illegal index 10000 derived from 9 with offset of 1000 and stride of 1000
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSumWithRow2(Nd4jBackend backend) { public void testSumWithRow2(Nd4jBackend backend) {
//All sums in this method execute without exceptions. //All sums in this method execute without exceptions.
INDArray array3d = Nd4j.ones(2, 10, 10); INDArray array3d = Nd4j.ones(2, 10, 10);
@ -1099,9 +1037,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPutRowFortran(Nd4jBackend backend) { public void testPutRowFortran(Nd4jBackend backend) {
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2).castTo(DataType.DOUBLE); INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2).castTo(DataType.DOUBLE);
INDArray put = Nd4j.create(new double[] {5, 6}); INDArray put = Nd4j.create(new double[] {5, 6});
@ -1114,9 +1051,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testElementWiseOps(Nd4jBackend backend) { public void testElementWiseOps(Nd4jBackend backend) {
INDArray n1 = Nd4j.scalar(1); INDArray n1 = Nd4j.scalar(1);
INDArray n2 = Nd4j.scalar(2); INDArray n2 = Nd4j.scalar(2);
@ -1139,9 +1075,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRollAxis(Nd4jBackend backend) { public void testRollAxis(Nd4jBackend backend) {
INDArray toRoll = Nd4j.ones(3, 4, 5, 6); INDArray toRoll = Nd4j.ones(3, 4, 5, 6);
assertArrayEquals(new long[] {3, 6, 4, 5}, Nd4j.rollAxis(toRoll, 3, 1).shape()); assertArrayEquals(new long[] {3, 6, 4, 5}, Nd4j.rollAxis(toRoll, 3, 1).shape());
@ -1163,9 +1098,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNegativeShape(Nd4jBackend backend) { public void testNegativeShape(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
INDArray reshaped = linspace.reshape(-1, 2); INDArray reshaped = linspace.reshape(-1, 2);
@ -1177,9 +1111,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetColumnGetRow(Nd4jBackend backend) { public void testGetColumnGetRow(Nd4jBackend backend) {
INDArray row = Nd4j.ones(1, 5); INDArray row = Nd4j.ones(1, 5);
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
@ -1194,9 +1127,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDupAndDupWithOrder(Nd4jBackend backend) { public void testDupAndDupWithOrder(Nd4jBackend backend) {
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE); List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
int count = 0; int count = 0;
@ -1218,9 +1150,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testToOffsetZeroCopy(Nd4jBackend backend) { public void testToOffsetZeroCopy(Nd4jBackend backend) {
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE); List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);

View File

@ -69,9 +69,8 @@ public class Nd4jTestsComparisonC extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) { public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);

View File

@ -71,9 +71,8 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
return 'f'; return 'f';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCrash(Nd4jBackend backend) { public void testCrash(Nd4jBackend backend) {
INDArray array3d = Nd4j.ones(1, 10, 10); INDArray array3d = Nd4j.ones(1, 10, 10);
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 0); Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 0);
@ -83,9 +82,8 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array4d, 0); Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array4d, 0);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMmulWithOpsCommonsMath(Nd4jBackend backend) { public void testMmulWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
@ -100,9 +98,8 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) { public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
@ -158,9 +155,8 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemvApacheCommons(Nd4jBackend backend) { public void testGemvApacheCommons(Nd4jBackend backend) {
int[] rowsArr = new int[] {4, 4, 4, 8, 8, 8}; int[] rowsArr = new int[] {4, 4, 4, 8, 8, 8};
@ -215,9 +211,8 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAddSubtractWithOpsCommonsMath(Nd4jBackend backend) { public void testAddSubtractWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
@ -235,9 +230,8 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMulDivOnCheckUtilMatrices(Nd4jBackend backend) { public void testMulDivOnCheckUtilMatrices(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);

View File

@ -42,9 +42,8 @@ public class Nd4jTestsF extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType(); DataType initialType = Nd4j.dataType();
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConcat3D_Vstack_F(Nd4jBackend backend) { public void testConcat3D_Vstack_F(Nd4jBackend backend) {
//Nd4j.getExecutioner().enableVerboseMode(true); //Nd4j.getExecutioner().enableVerboseMode(true);
//Nd4j.getExecutioner().enableDebugMode(true); //Nd4j.getExecutioner().enableDebugMode(true);
@ -76,9 +75,8 @@ public class Nd4jTestsF extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSlice_1(Nd4jBackend backend) { public void testSlice_1(Nd4jBackend backend) {
val arr = Nd4j.linspace(1,4, 4, DataType.DOUBLE).reshape(2, 2, 1); val arr = Nd4j.linspace(1,4, 4, DataType.DOUBLE).reshape(2, 2, 1);
val exp0 = Nd4j.create(new double[]{1, 3}, new int[] {2, 1}); val exp0 = Nd4j.create(new double[]{1, 3}, new int[] {2, 1});

View File

@ -32,15 +32,14 @@ import org.nd4j.common.util.ArrayUtil;
import java.util.*; import java.util.*;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class ShufflesTests extends BaseNd4jTestWithBackends { public class ShufflesTests extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSimpleShuffle1(Nd4jBackend backend) { public void testSimpleShuffle1(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(10, 10); INDArray array = Nd4j.zeros(10, 10);
for (int x = 0; x < 10; x++) { for (int x = 0; x < 10; x++) {
@ -62,9 +61,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
assertTrue(scanner.compareRow(array)); assertTrue(scanner.compareRow(array));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSimpleShuffle2(Nd4jBackend backend) { public void testSimpleShuffle2(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(10, 10); INDArray array = Nd4j.zeros(10, 10);
for (int x = 0; x < 10; x++) { for (int x = 0; x < 10; x++) {
@ -79,9 +77,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
assertTrue(scanner.compareColumn(array)); assertTrue(scanner.compareColumn(array));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSimpleShuffle3(Nd4jBackend backend) { public void testSimpleShuffle3(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(11, 10); INDArray array = Nd4j.zeros(11, 10);
for (int x = 0; x < 11; x++) { for (int x = 0; x < 11; x++) {
@ -97,9 +94,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
assertTrue(scanner.compareRow(array)); assertTrue(scanner.compareRow(array));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSymmetricShuffle1(Nd4jBackend backend) { public void testSymmetricShuffle1(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10); INDArray features = Nd4j.zeros(10, 10);
INDArray labels = Nd4j.zeros(10, 3); INDArray labels = Nd4j.zeros(10, 3);
@ -137,9 +133,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSymmetricShuffle2(Nd4jBackend backend) { public void testSymmetricShuffle2(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10, 20); INDArray features = Nd4j.zeros(10, 10, 20);
INDArray labels = Nd4j.zeros(10, 10, 3); INDArray labels = Nd4j.zeros(10, 10, 3);
@ -177,9 +172,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSymmetricShuffle3(Nd4jBackend backend) { public void testSymmetricShuffle3(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10, 20); INDArray features = Nd4j.zeros(10, 10, 20);
INDArray featuresMask = Nd4j.zeros(10, 20); INDArray featuresMask = Nd4j.zeros(10, 20);
@ -244,9 +238,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
* There's SMALL chance this test will randomly fail, since spread isn't too big * There's SMALL chance this test will randomly fail, since spread isn't too big
* @throws Exception * @throws Exception
*/ */
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testHalfVectors1(Nd4jBackend backend) { public void testHalfVectors1(Nd4jBackend backend) {
int[] array1 = ArrayUtil.buildHalfVector(new Random(12), 20); int[] array1 = ArrayUtil.buildHalfVector(new Random(12), 20);
int[] array2 = ArrayUtil.buildHalfVector(new Random(75), 20); int[] array2 = ArrayUtil.buildHalfVector(new Random(75), 20);
@ -267,9 +260,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInterleavedVector1(Nd4jBackend backend) { public void testInterleavedVector1(Nd4jBackend backend) {
int[] array1 = ArrayUtil.buildInterleavedVector(new Random(12), 20); int[] array1 = ArrayUtil.buildInterleavedVector(new Random(12), 20);
int[] array2 = ArrayUtil.buildInterleavedVector(new Random(75), 20); int[] array2 = ArrayUtil.buildInterleavedVector(new Random(75), 20);
@ -290,9 +282,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInterleavedVector3(Nd4jBackend backend) { public void testInterleavedVector3(Nd4jBackend backend) {
for (int e = 0; e < 1000; e++) { for (int e = 0; e < 1000; e++) {
int length = e + 256; //RandomUtils.nextInt(121, 2073); int length = e + 256; //RandomUtils.nextInt(121, 2073);

View File

@ -54,9 +54,8 @@ public class TestEigen extends BaseNd4jTestWithBackends {
// test of functions added by Luke Czapla // test of functions added by Luke Czapla
// Compares solution of A x = L x to solution to A x = L B x when it is simple // Compares solution of A x = L x to solution to A x = L B x when it is simple
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void test2Syev(Nd4jBackend backend) { public void test2Syev(Nd4jBackend backend) {
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
Nd4j.setDefaultDataTypes(dt, dt); Nd4j.setDefaultDataTypes(dt, dt);
@ -75,9 +74,8 @@ public class TestEigen extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSyev(Nd4jBackend backend) { public void testSyev(Nd4jBackend backend) {
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
//log.info("Datatype: {}", dt); //log.info("Datatype: {}", dt);

View File

@ -37,9 +37,8 @@ import org.nd4j.common.util.ArrayUtil;
@Slf4j @Slf4j
public class ToStringTest extends BaseNd4jTestWithBackends { public class ToStringTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testToString(Nd4jBackend backend) throws Exception { public void testToString(Nd4jBackend backend) throws Exception {
assertEquals("[ 1, 2, 3]", assertEquals("[ 1, 2, 3]",
Nd4j.createFromArray(1, 2, 3).toString()); Nd4j.createFromArray(1, 2, 3).toString());
@ -57,9 +56,8 @@ public class ToStringTest extends BaseNd4jTestWithBackends {
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(6, true, 1)); Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(6, true, 1));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testToStringScalars(){ public void testToStringScalars(){
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32}; DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"}; String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};

View File

@ -53,8 +53,9 @@ import java.util.Arrays;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestActivation extends BaseNd4jTestWithBackends { public class TestActivation extends BaseNd4jTestWithBackends {
@ -76,9 +77,8 @@ public class TestActivation extends BaseNd4jTestWithBackends {
mapper.enable(SerializationFeature.INDENT_OUTPUT); mapper.enable(SerializationFeature.INDENT_OUTPUT);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRelu(Nd4jBackend backend){ public void testRelu(Nd4jBackend backend){
Double[] max = {null, 6.0, 2.5, 5.0}; Double[] max = {null, 6.0, 2.5, 5.0};
@ -130,9 +130,8 @@ public class TestActivation extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testJson(Nd4jBackend backend) throws Exception { public void testJson(Nd4jBackend backend) throws Exception {
IActivation[] activations = new IActivation[] {new ActivationCube(), new ActivationELU(0.25), IActivation[] activations = new IActivation[] {new ActivationCube(), new ActivationELU(0.25),
@ -179,7 +178,7 @@ public class TestActivation extends BaseNd4jTestWithBackends {
for (String s : expFields) { for (String s : expFields) {
msg = "Expected field \"" + s + "\", was not found in " + activations[i].toString(); msg = "Expected field \"" + s + "\", was not found in " + activations[i].toString();
assertTrue(msg, actualFieldsByName.contains(s)); assertTrue(actualFieldsByName.contains(s),msg);
} }
//Test conversion from JSON: //Test conversion from JSON:

View File

@ -30,9 +30,8 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
public class TestBackend extends BaseNd4jTestWithBackends { public class TestBackend extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBuildInfo(Nd4jBackend backend){ public void testBuildInfo(Nd4jBackend backend){
System.out.println("Backend build info: " + backend.buildInfo()); System.out.println("Backend build info: " + backend.buildInfo());
} }

View File

@ -37,9 +37,8 @@ public class TestEnvironment extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEnvironment(Nd4jBackend backend){ public void testEnvironment(Nd4jBackend backend){
Environment e = Nd4j.getEnvironment(); Environment e = Nd4j.getEnvironment();
System.out.println("BLAS version: " + e.blasMajorVersion() + "." + e.blasMinorVersion() + "." + e.blasPatchVersion()); System.out.println("BLAS version: " + e.blasMajorVersion() + "." + e.blasMinorVersion() + "." + e.blasPatchVersion());

View File

@ -44,9 +44,8 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class TestNDArrayCreation extends BaseNd4jTestWithBackends { public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBufferCreation(Nd4jBackend backend) { public void testBufferCreation(Nd4jBackend backend) {
DataBuffer dataBuffer = Nd4j.createBuffer(new float[] {1, 2}); DataBuffer dataBuffer = Nd4j.createBuffer(new float[] {1, 2});
Pointer pointer = dataBuffer.pointer(); Pointer pointer = dataBuffer.pointer();
@ -68,7 +67,7 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
@Test @Test
@Disabled @Disabled
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCreateNpy() throws Exception { public void testCreateNpy() throws Exception {
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile()); INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile());
assertEquals(2, arrCreate.size(0)); assertEquals(2, arrCreate.size(0));
@ -83,7 +82,7 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
@Test @Test
@Disabled @Disabled
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCreateNpz(Nd4jBackend backend) throws Exception { public void testCreateNpz(Nd4jBackend backend) throws Exception {
Map<String, INDArray> map = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile()); Map<String, INDArray> map = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile());
assertEquals(true, map.containsKey("x")); assertEquals(true, map.containsKey("x"));

View File

@ -35,9 +35,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
public class TestNDArrayCreationUtil extends BaseNd4jTestWithBackends { public class TestNDArrayCreationUtil extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testShapes() { public void testShapes() {
long[] shape2d = {2, 3}; long[] shape2d = {2, 3};

View File

@ -32,9 +32,8 @@ import org.nd4j.linalg.factory.Nd4jBackend;
public class TestNamespaces extends BaseNd4jTestWithBackends { public class TestNamespaces extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBitwiseSimple(Nd4jBackend backend){ public void testBitwiseSimple(Nd4jBackend backend){
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT); INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
@ -50,9 +49,8 @@ public class TestNamespaces extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMathSimple(Nd4jBackend backend) { public void testMathSimple(Nd4jBackend backend) {
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1); INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1);
INDArray abs = Nd4j.math.abs(x); INDArray abs = Nd4j.math.abs(x);
@ -67,9 +65,8 @@ public class TestNamespaces extends BaseNd4jTestWithBackends {
// System.out.println(cm); // System.out.println(cm);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRandomSimple(Nd4jBackend backend){ public void testRandomSimple(Nd4jBackend backend){
INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10); INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10);
// System.out.println(normal); // System.out.println(normal);
@ -77,9 +74,8 @@ public class TestNamespaces extends BaseNd4jTestWithBackends {
// System.out.println(uniform); // System.out.println(uniform);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNeuralNetworkSimple(Nd4jBackend backend){ public void testNeuralNetworkSimple(Nd4jBackend backend){
INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10)); INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10));
// System.out.println(out); // System.out.println(out);

View File

@ -36,9 +36,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
public class LapackTest extends BaseNd4jTestWithBackends { public class LapackTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testQRSquare(Nd4jBackend backend) { public void testQRSquare(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9}); INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9});
A = A.reshape('c', 3, 3); A = A.reshape('c', 3, 3);
@ -56,9 +55,8 @@ public class LapackTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testQRRect(Nd4jBackend backend) { public void testQRRect(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
A = A.reshape('f', 4, 3); A = A.reshape('f', 4, 3);
@ -76,9 +74,8 @@ public class LapackTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCholeskyL(Nd4jBackend backend) { public void testCholeskyL(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {2, -1, 1, -1, 2, -1, 1, -1, 2,}); INDArray A = Nd4j.create(new double[] {2, -1, 1, -1, 2, -1, 1, -1, 2,});
A = A.reshape('c', 3, 3); A = A.reshape('c', 3, 3);
@ -95,9 +92,8 @@ public class LapackTest extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCholeskyU(Nd4jBackend backend) { public void testCholeskyU(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,}); INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,});
A = A.reshape('f', 3, 3); A = A.reshape('f', 3, 3);

View File

@ -39,9 +39,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
public class Level1Test extends BaseNd4jTestWithBackends { public class Level1Test extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDot(Nd4jBackend backend) { public void testDot(Nd4jBackend backend) {
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4});
@ -54,9 +53,8 @@ public class Level1Test extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAxpy(Nd4jBackend backend) { public void testAxpy(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray row = matrix.getRow(1); INDArray row = matrix.getRow(1);
@ -65,9 +63,8 @@ public class Level1Test extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAxpy2(Nd4jBackend backend) { public void testAxpy2(Nd4jBackend backend) {
val rowX = Nd4j.create(new double[]{1, 2, 3, 4}); val rowX = Nd4j.create(new double[]{1, 2, 3, 4});
val rowY = Nd4j.create(new double[]{1, 2, 3, 4}); val rowY = Nd4j.create(new double[]{1, 2, 3, 4});

View File

@ -34,9 +34,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
public class Level2Test extends BaseNd4jTestWithBackends { public class Level2Test extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemv1(Nd4jBackend backend) { public void testGemv1(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -50,9 +49,8 @@ public class Level2Test extends BaseNd4jTestWithBackends {
assertEquals(1853350f, array3.getFloat(3), 0.001f); assertEquals(1853350f, array3.getFloat(3), 0.001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemv2(Nd4jBackend backend) { public void testGemv2(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -66,9 +64,8 @@ public class Level2Test extends BaseNd4jTestWithBackends {
assertEquals(1853350f, array3.getFloat(3), 0.001f); assertEquals(1853350f, array3.getFloat(3), 0.001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemv3(Nd4jBackend backend) { public void testGemv3(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -82,9 +79,8 @@ public class Level2Test extends BaseNd4jTestWithBackends {
assertEquals(3353200f, array3.getFloat(3), 0.001f); assertEquals(3353200f, array3.getFloat(3), 0.001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemv4(Nd4jBackend backend) { public void testGemv4(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -98,9 +94,8 @@ public class Level2Test extends BaseNd4jTestWithBackends {
assertEquals(3353200f, array3.getFloat(3), 0.001f); assertEquals(3353200f, array3.getFloat(3), 0.001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemv5(Nd4jBackend backend) { public void testGemv5(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -116,9 +111,8 @@ public class Level2Test extends BaseNd4jTestWithBackends {
assertEquals(1853350f, array3.getFloat(3), 0.001f); assertEquals(1853350f, array3.getFloat(3), 0.001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemv6(Nd4jBackend backend) { public void testGemv6(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -134,9 +128,8 @@ public class Level2Test extends BaseNd4jTestWithBackends {
assertEquals(3353200f, array3.getFloat(3), 0.001f); assertEquals(3353200f, array3.getFloat(3), 0.001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemv7(Nd4jBackend backend) { public void testGemv7(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);

View File

@ -34,9 +34,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
public class Level3Test extends BaseNd4jTestWithBackends { public class Level3Test extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemm1(Nd4jBackend backend) { public void testGemm1(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape(1, 100); INDArray array1 = Nd4j.linspace(1, 100, 100).reshape(1, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -46,9 +45,8 @@ public class Level3Test extends BaseNd4jTestWithBackends {
assertEquals(338350f, array3.getFloat(0), 0.001f); assertEquals(338350f, array3.getFloat(0), 0.001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemm2(Nd4jBackend backend) { public void testGemm2(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape('f', 1, 100); INDArray array1 = Nd4j.linspace(1, 100, 100).reshape('f', 1, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -58,9 +56,8 @@ public class Level3Test extends BaseNd4jTestWithBackends {
assertEquals(338350f, array3.getFloat(0), 0.001f); assertEquals(338350f, array3.getFloat(0), 0.001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemm3(Nd4jBackend backend) { public void testGemm3(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
@ -78,9 +75,8 @@ public class Level3Test extends BaseNd4jTestWithBackends {
assertEquals(8328150.0f, array3.data().getFloat(21), 0.001f); assertEquals(8328150.0f, array3.data().getFloat(21), 0.001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemm4(Nd4jBackend backend) { public void testGemm4(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);
@ -97,9 +93,8 @@ public class Level3Test extends BaseNd4jTestWithBackends {
assertEquals(3853350f, array3.data().getFloat(21), 0.001f); assertEquals(3853350f, array3.data().getFloat(21), 0.001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemm5(Nd4jBackend backend) { public void testGemm5(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
@ -113,9 +108,8 @@ public class Level3Test extends BaseNd4jTestWithBackends {
assertEquals(3.3835E7f, array3.data().getFloat(99), 10f); assertEquals(3.3835E7f, array3.data().getFloat(99), 10f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemm6(Nd4jBackend backend) { public void testGemm6(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);

View File

@ -37,9 +37,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
public class ParamsTestsF extends BaseNd4jTestWithBackends { public class ParamsTestsF extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGemm (Nd4jBackend backend) { public void testGemm (Nd4jBackend backend) {
INDArray a = Nd4j.create(2, 2); INDArray a = Nd4j.create(2, 2);
INDArray b = Nd4j.create(2, 3); INDArray b = Nd4j.create(2, 3);

View File

@ -53,7 +53,7 @@ public class DataBufferTests extends BaseNd4jTestWithBackends {
@Test @Test
@Disabled("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657") @Disabled("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657")
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNoArgCreateBufferFromArray(Nd4jBackend backend) { public void testNoArgCreateBufferFromArray(Nd4jBackend backend) {
//Tests here: //Tests here:
@ -279,9 +279,8 @@ public class DataBufferTests extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCreateTypedBuffer(Nd4jBackend backend) { public void testCreateTypedBuffer(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
@ -351,9 +350,8 @@ public class DataBufferTests extends BaseNd4jTestWithBackends {
} }
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAsBytes(Nd4jBackend backend) { public void testAsBytes(Nd4jBackend backend) {
INDArray orig = Nd4j.linspace(DataType.INT, 0, 10, 1); INDArray orig = Nd4j.linspace(DataType.INT, 0, 10, 1);
@ -408,9 +406,8 @@ public class DataBufferTests extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEnsureLocation(){ public void testEnsureLocation(){
//https://github.com/eclipse/deeplearning4j/issues/8783 //https://github.com/eclipse/deeplearning4j/issues/8783
Nd4j.create(1); Nd4j.create(1);

View File

@ -72,7 +72,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
*/ */
@Test() @Test()
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBlasValidation1(Nd4jBackend backend) { public void testBlasValidation1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> { assertThrows(ND4JIllegalStateException.class,() -> {
INDArray x = Nd4j.create(10); INDArray x = Nd4j.create(10);
@ -91,7 +91,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
*/ */
@Test() @Test()
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBlasValidation2(Nd4jBackend backend) { public void testBlasValidation2(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> { assertThrows(RuntimeException.class,() -> {
INDArray a = Nd4j.create(100, 10); INDArray a = Nd4j.create(100, 10);
@ -111,7 +111,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
*/ */
@Test() @Test()
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBlasValidation3(Nd4jBackend backend) { public void testBlasValidation3(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
INDArray x = Nd4j.create(100, 100); INDArray x = Nd4j.create(100, 100);

View File

@ -76,9 +76,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
DataTypeUtil.setDTypeForContext(initialType); DataTypeUtil.setDTypeForContext(initialType);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPointerCreation(Nd4jBackend backend) { public void testPointerCreation(Nd4jBackend backend) {
DoublePointer floatPointer = new DoublePointer(1, 2, 3, 4); DoublePointer floatPointer = new DoublePointer(1, 2, 3, 4);
Indexer indexer = DoubleIndexer.create(floatPointer); Indexer indexer = DoubleIndexer.create(floatPointer);
@ -87,9 +86,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
assertArrayEquals(other.asDouble(), buffer.asDouble(), 0.001); assertArrayEquals(other.asDouble(), buffer.asDouble(), 0.001);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetSet(Nd4jBackend backend) { public void testGetSet(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4}; double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
@ -100,9 +98,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSerialization2() throws Exception { public void testSerialization2() throws Exception {
INDArray[] arr = new INDArray[] {Nd4j.ones(1, 10), INDArray[] arr = new INDArray[] {Nd4j.ones(1, 10),
// Nd4j.ones(5,10).getRow(2) // Nd4j.ones(5,10).getRow(2)
@ -130,9 +127,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSerialization(@TempDir Path testDir) throws Exception { public void testSerialization(@TempDir Path testDir) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
DataBuffer buf = Nd4j.createBuffer(5); DataBuffer buf = Nd4j.createBuffer(5);
@ -154,9 +150,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDup(Nd4jBackend backend) { public void testDup(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4}; double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
@ -166,9 +161,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPut(Nd4jBackend backend) { public void testPut(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4}; double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
@ -179,9 +173,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetRange(Nd4jBackend backend) { public void testGetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
double[] get = buffer.getDoublesAt(0, 3); double[] get = buffer.getDoublesAt(0, 3);
@ -196,9 +189,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetOffsetRange(Nd4jBackend backend) { public void testGetOffsetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
double[] get = buffer.getDoublesAt(1, 3); double[] get = buffer.getDoublesAt(1, 3);
@ -213,9 +205,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAssign(Nd4jBackend backend) { public void testAssign(Nd4jBackend backend) {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3}); DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
DataBuffer one = Nd4j.createBuffer(new double[] {1}); DataBuffer one = Nd4j.createBuffer(new double[] {1});
@ -226,9 +217,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOffset(Nd4jBackend backend) { public void testOffset(Nd4jBackend backend) {
DataBuffer create = Nd4j.createBuffer(new double[] {1, 2, 3, 4}, 2); DataBuffer create = Nd4j.createBuffer(new double[] {1, 2, 3, 4}, 2);
assertEquals(2, create.length()); assertEquals(2, create.length());
@ -238,9 +228,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReallocation(Nd4jBackend backend) { public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4}); DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity()); assertEquals(4, buffer.capacity());
@ -250,9 +239,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
assertArrayEquals(old, Arrays.copyOf(buffer.asDouble(), 4), 1e-1); assertArrayEquals(old, Arrays.copyOf(buffer.asDouble(), 4), 1e-1);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReallocationWorkspace(Nd4jBackend backend) { public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
@ -269,9 +257,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAddressPointer(){ public void testAddressPointer(){
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){ if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
return; return;

View File

@ -72,9 +72,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPointerCreation(Nd4jBackend backend) { public void testPointerCreation(Nd4jBackend backend) {
FloatPointer floatPointer = new FloatPointer(1, 2, 3, 4); FloatPointer floatPointer = new FloatPointer(1, 2, 3, 4);
Indexer indexer = FloatIndexer.create(floatPointer); Indexer indexer = FloatIndexer.create(floatPointer);
@ -83,9 +82,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
assertArrayEquals(other.asFloat(), buffer.asFloat(), 0.001f); assertArrayEquals(other.asFloat(), buffer.asFloat(), 0.001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetSet(Nd4jBackend backend) { public void testGetSet(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4}; float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
@ -96,9 +94,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSerialization(@TempDir Path tempDir,Nd4jBackend backend) throws Exception { public void testSerialization(@TempDir Path tempDir,Nd4jBackend backend) throws Exception {
File dir = tempDir.toFile(); File dir = tempDir.toFile();
DataBuffer buf = Nd4j.createBuffer(5); DataBuffer buf = Nd4j.createBuffer(5);
@ -119,9 +116,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
assertArrayEquals(buf.asFloat(), buf2.asFloat(), 0.0001f); assertArrayEquals(buf.asFloat(), buf2.asFloat(), 0.0001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDup(Nd4jBackend backend) { public void testDup(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4}; float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
@ -129,9 +125,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
assertArrayEquals(d.asFloat(), d2.asFloat(), 0.001f); assertArrayEquals(d.asFloat(), d2.asFloat(), 0.001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testToNio(Nd4jBackend backend) { public void testToNio(Nd4jBackend backend) {
DataBuffer buff = Nd4j.createTypedBuffer(new double[] {1, 2, 3, 4}, DataType.FLOAT); DataBuffer buff = Nd4j.createTypedBuffer(new double[] {1, 2, 3, 4}, DataType.FLOAT);
assertEquals(4, buff.length()); assertEquals(4, buff.length());
@ -143,9 +138,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPut(Nd4jBackend backend) { public void testPut(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4}; float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
@ -156,9 +150,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetRange(Nd4jBackend backend) { public void testGetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(0, 3); float[] get = buffer.getFloatsAt(0, 3);
@ -174,9 +167,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetOffsetRange(Nd4jBackend backend) { public void testGetOffsetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(1, 3); float[] get = buffer.getFloatsAt(1, 3);
@ -193,9 +185,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAsBytes(Nd4jBackend backend) { public void testAsBytes(Nd4jBackend backend) {
INDArray arr = Nd4j.create(5); INDArray arr = Nd4j.create(5);
byte[] d = arr.data().asBytes(); byte[] d = arr.data().asBytes();
@ -205,9 +196,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAssign(Nd4jBackend backend) { public void testAssign(Nd4jBackend backend) {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3}); DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
DataBuffer one = Nd4j.createBuffer(new double[] {1}); DataBuffer one = Nd4j.createBuffer(new double[] {1});
@ -217,9 +207,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
assertArrayEquals(assertion.asFloat(), blank.asFloat(), 0.0001f); assertArrayEquals(assertion.asFloat(), blank.asFloat(), 0.0001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReadWrite(Nd4jBackend backend) throws Exception { public void testReadWrite(Nd4jBackend backend) throws Exception {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3}); DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
ByteArrayOutputStream bos = new ByteArrayOutputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream();
@ -233,9 +222,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
assertArrayEquals(assertion.asFloat(), clone.asFloat(), 0.0001f); assertArrayEquals(assertion.asFloat(), clone.asFloat(), 0.0001f);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOffset(Nd4jBackend backend) { public void testOffset(Nd4jBackend backend) {
DataBuffer create = Nd4j.createBuffer(new float[] {1, 2, 3, 4}, 2); DataBuffer create = Nd4j.createBuffer(new float[] {1, 2, 3, 4}, 2);
assertEquals(2, create.length()); assertEquals(2, create.length());
@ -245,9 +233,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReallocation(Nd4jBackend backend) { public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4}); DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity()); assertEquals(4, buffer.capacity());
@ -258,9 +245,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
assertArrayEquals(old, newBuf, 1e-4F); assertArrayEquals(old, newBuf, 1e-4F);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReallocationWorkspace(Nd4jBackend backend) { public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
@ -277,9 +263,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
workspace.close(); workspace.close();
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAddressPointer(Nd4jBackend backend){ public void testAddressPointer(Nd4jBackend backend){
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){ if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
return; return;

View File

@ -42,9 +42,8 @@ import static org.junit.jupiter.api.Assertions.*;
public class IntDataBufferTests extends BaseNd4jTestWithBackends { public class IntDataBufferTests extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBasicSerde1() throws Exception { public void testBasicSerde1() throws Exception {
@ -82,9 +81,8 @@ public class IntDataBufferTests extends BaseNd4jTestWithBackends {
} }
*/ */
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReallocation(Nd4jBackend backend) { public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4}); DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity()); assertEquals(4, buffer.capacity());
@ -96,9 +94,8 @@ public class IntDataBufferTests extends BaseNd4jTestWithBackends {
assertArrayEquals(old, Arrays.copyOf(newContent, old.length)); assertArrayEquals(old, Arrays.copyOf(newContent, old.length));
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReallocationWorkspace(Nd4jBackend backend) { public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();

View File

@ -43,9 +43,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testINDArrayIndexingEqualToRank(Nd4jBackend backend) { public void testINDArrayIndexingEqualToRank(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE); INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{ INDArray indexes = Nd4j.create(new double[][]{
@ -60,9 +59,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testINDArrayIndexingLessThanRankSimple(Nd4jBackend backend) { public void testINDArrayIndexingLessThanRankSimple(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE); INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{ INDArray indexes = Nd4j.create(new double[][]{
@ -76,9 +74,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testINDArrayIndexingLessThanRankFourDimension(Nd4jBackend backend) { public void testINDArrayIndexingLessThanRankFourDimension(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE); INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{ INDArray indexes = Nd4j.create(new double[][]{
@ -91,9 +88,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPutSimple(Nd4jBackend backend) { public void testPutSimple(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2); INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2);
INDArray indexes = Nd4j.create(new double[][]{ INDArray indexes = Nd4j.create(new double[][]{
@ -105,9 +101,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
assertEquals(vals,x); assertEquals(vals,x);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetScalar(Nd4jBackend backend) { public void testGetScalar(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
INDArray d = arr.get(NDArrayIndex.point(1)); INDArray d = arr.get(NDArrayIndex.point(1));
@ -116,18 +111,16 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNewAxis(Nd4jBackend backend) { public void testNewAxis(Nd4jBackend backend) {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3}); INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1)); INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
// System.out.println(view); // System.out.println(view);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testVectorIndexing(Nd4jBackend backend) { public void testVectorIndexing(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE); INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE);
int[] index = new int[] {5, 8, 9}; int[] index = new int[] {5, 8, 9};
@ -139,9 +132,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetRowsColumnsMatrix(Nd4jBackend backend) { public void testGetRowsColumnsMatrix(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 6); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 6);
INDArray firstAndSecondColumnsAssertion = Nd4j.create(new double[][] {{1, 5}, {2, 6}, {3, 7}, {4, 8}}); INDArray firstAndSecondColumnsAssertion = Nd4j.create(new double[][] {{1, 5}, {2, 6}, {3, 7}, {4, 8}});
@ -159,9 +151,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSlicing(Nd4jBackend backend) { public void testSlicing(Nd4jBackend backend) {
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE); INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
INDArray slice1Assert = Nd4j.create(new double[] {2, 6, 10, 14}); INDArray slice1Assert = Nd4j.create(new double[] {2, 6, 10, 14});
@ -169,9 +160,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
assertEquals(slice1Assert, slice1Test); assertEquals(slice1Assert, slice1Test);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testArangeMul(Nd4jBackend backend) { public void testArangeMul(Nd4jBackend backend) {
INDArray arange = Nd4j.arange(1, 17).reshape('f', 4, 4).castTo(DataType.DOUBLE); INDArray arange = Nd4j.arange(1, 17).reshape('f', 4, 4).castTo(DataType.DOUBLE);
INDArrayIndex index = NDArrayIndex.interval(0, 2); INDArrayIndex index = NDArrayIndex.interval(0, 2);
@ -183,9 +173,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetIndicesVector(Nd4jBackend backend) { public void testGetIndicesVector(Nd4jBackend backend) {
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1); INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
INDArray test = Nd4j.create(new double[] {2, 3}); INDArray test = Nd4j.create(new double[] {2, 3});
@ -193,9 +182,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
assertEquals(test, result); assertEquals(test, result);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetIndicesVectorView(Nd4jBackend backend) { public void testGetIndicesVectorView(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape('c',5, 5); INDArray matrix = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape('c',5, 5);
INDArray column = matrix.getColumn(0).reshape(1,5); INDArray column = matrix.getColumn(0).reshape(1,5);
@ -213,9 +201,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
assertEquals(exp2, result); assertEquals(exp2, result);
} }
@Test
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void test2dGetPoint(Nd4jBackend backend){ public void test2dGetPoint(Nd4jBackend backend){
INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4); INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4);
for( int i=0; i<3; i++ ){ for( int i=0; i<3; i++ ){

Some files were not shown because too many files have changed in this diff Show More