Migrate parameterized tests to junit 5

master
agibsonccc 2021-03-16 22:08:35 +09:00
parent 82bdcc21d2
commit 3c6014271e
260 changed files with 10787 additions and 6270 deletions

View File

@ -37,8 +37,10 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -47,13 +49,14 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.deeplearning4j.nn.conf.ConvolutionMode.Same;
import static org.deeplearning4j.nn.conf.ConvolutionMode.Truncate;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(Parameterized.class)
@DisplayName("Cnn Gradient Check Test")
class CNNGradientCheckTest extends BaseDL4JTest {
@ -71,15 +74,10 @@ class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
private CNN2DFormat format;
public CNNGradientCheckTest(CNN2DFormat format) {
this.format = format;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params() {
return CNN2DFormat.values();
public static Stream<Arguments> params() {
return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of);
}
@Override
@ -89,9 +87,11 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Gradient CNNMLN")
void testGradientCNNMLN() {
@ParameterizedTest
@MethodSource("#params")
public void testGradientCNNMLN(CNN2DFormat format) {
if (// Only test NCHW due to flat input format...
this.format != CNN2DFormat.NCHW)
format != CNN2DFormat.NCHW)
return;
// Parameterized test, testing combinations of:
// (a) activation function
@ -146,9 +146,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Gradient CNNL 1 L 2 MLN")
void testGradientCNNL1L2MLN() {
void testGradientCNNL1L2MLN(CNN2DFormat format) {
if (// Only test NCHW due to flat input format...
this.format != CNN2DFormat.NCHW)
format != CNN2DFormat.NCHW)
return;
// Parameterized test, testing combinations of:
// (a) activation function
@ -245,7 +245,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn With Space To Batch")
void testCnnWithSpaceToBatch() {
@ParameterizedTest
@MethodSource("#params")
public void testCnnWithSpaceToBatch(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345);
int nOut = 4;
int[] minibatchSizes = { 2, 4 };
@ -289,7 +291,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn With Upsampling")
void testCnnWithUpsampling() {
@ParameterizedTest
@MethodSource("#params")
void testCnnWithUpsampling(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345);
int nOut = 4;
int[] minibatchSizes = { 1, 3 };
@ -323,7 +327,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn With Subsampling")
void testCnnWithSubsampling() {
@ParameterizedTest
@MethodSource("#params")
void testCnnWithSubsampling(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345);
int nOut = 4;
int[] minibatchSizes = { 1, 3 };
@ -365,7 +371,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn With Subsampling V 2")
void testCnnWithSubsamplingV2() {
@ParameterizedTest
@MethodSource("#params")
void testCnnWithSubsamplingV2(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345);
int nOut = 4;
int[] minibatchSizes = { 1, 3 };
@ -403,7 +411,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn Locally Connected 2 D")
void testCnnLocallyConnected2D() {
@ParameterizedTest
@MethodSource("#params")
void testCnnLocallyConnected2D(CNN2DFormat format) {
int nOut = 3;
int width = 5;
int height = 5;
@ -433,7 +443,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn Multi Layer")
void testCnnMultiLayer() {
@ParameterizedTest
@MethodSource("#params")
void testCnnMultiLayer(CNN2DFormat format) {
int nOut = 2;
int[] minibatchSizes = { 1, 2, 5 };
int width = 5;
@ -473,7 +485,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn Same Padding Mode")
void testCnnSamePaddingMode() {
@ParameterizedTest
@MethodSource("#params")
void testCnnSamePaddingMode(CNN2DFormat format) {
int nOut = 2;
int[] minibatchSizes = { 1, 3, 3, 2, 1, 2 };
// Same padding mode: insensitive to exact input size...
@ -507,7 +521,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn Same Padding Mode Strided")
void testCnnSamePaddingModeStrided() {
@ParameterizedTest
@MethodSource("#params")
void testCnnSamePaddingModeStrided(CNN2DFormat format) {
int nOut = 2;
int[] minibatchSizes = { 1, 3 };
int width = 16;
@ -550,7 +566,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn Zero Padding Layer")
void testCnnZeroPaddingLayer() {
@ParameterizedTest
@MethodSource("#params")
void testCnnZeroPaddingLayer(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345);
int nOut = 4;
int width = 6;
@ -596,7 +614,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Deconvolution 2 D")
void testDeconvolution2D() {
@ParameterizedTest
@MethodSource("#params")
void testDeconvolution2D(CNN2DFormat format) {
int nOut = 2;
int[] minibatchSizes = new int[] { 1, 3, 3, 1, 3 };
int[] kernelSizes = new int[] { 1, 1, 1, 3, 3 };
@ -641,7 +661,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Separable Conv 2 D")
void testSeparableConv2D() {
@ParameterizedTest
@MethodSource("#params")
void testSeparableConv2D(CNN2DFormat format) {
int nOut = 2;
int width = 6;
int height = 6;
@ -686,7 +708,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn Dilated")
void testCnnDilated() {
@ParameterizedTest
@MethodSource("#params")
void testCnnDilated(CNN2DFormat format) {
int nOut = 2;
int minibatchSize = 2;
int width = 8;
@ -736,7 +760,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cropping 2 D Layer")
void testCropping2DLayer() {
@ParameterizedTest
@MethodSource("#params")
void testCropping2DLayer(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345);
int nOut = 2;
int width = 12;
@ -780,7 +806,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Depthwise Conv 2 D")
void testDepthwiseConv2D() {
@ParameterizedTest
@MethodSource("#params")
void testDepthwiseConv2D(CNN2DFormat format) {
int nIn = 3;
int depthMultiplier = 2;
int nOut = nIn * depthMultiplier;

View File

@ -39,8 +39,10 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -55,26 +57,22 @@ import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class YoloGradientCheckTests extends BaseDL4JTest {
static {
Nd4j.setDataType(DataType.DOUBLE);
}
private CNN2DFormat format;
public YoloGradientCheckTests(CNN2DFormat format){
this.format = format;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return CNN2DFormat.values();
}
public static Stream<Arguments> params() {
return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of);
}
@Override
public long getTimeoutMilliseconds() {
@ -82,7 +80,9 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
}
@Test
public void testYoloOutputLayer() {
@ParameterizedTest
@MethodSource("#params")
public void testYoloOutputLayer(CNN2DFormat format) {
int depthIn = 2;
int c = 3;
int b = 3;
@ -181,7 +181,9 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
@Test
public void yoloGradientCheckRealData(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("#params")
public void yoloGradientCheckRealData(@TempDir Path testDir,CNN2DFormat format) throws Exception {
Nd4j.getRandom().setSeed(12345);
InputStream is1 = new ClassPathResource("yolo/VOC_TwoImage/JPEGImages/2007_009346.jpg").getInputStream();
InputStream is2 = new ClassPathResource("yolo/VOC_TwoImage/Annotations/2007_009346.xml").getInputStream();

View File

@ -39,8 +39,10 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -48,24 +50,19 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class ConvDataFormatTests extends BaseDL4JTest {
private final DataType dataType;
public ConvDataFormatTests(DataType dataType){
this.dataType = dataType;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return new DataType[]{DataType.FLOAT, DataType.DOUBLE};
public static Stream<Arguments> params(){
return Arrays.asList(new DataType[]{DataType.FLOAT, DataType.DOUBLE}).stream().map(Arguments::of);
}
@Override
@ -74,7 +71,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testConv2d() {
@MethodSource("#params")
@ParameterizedTest
public void testConv2d(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -83,15 +82,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getConv2dNet(CNN2DFormat.NHWC, false, cm))
.net1(getConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -107,7 +106,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testSubsampling2d() {
@MethodSource("#params")
@ParameterizedTest
public void testSubsampling2d(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -116,15 +117,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSubsampling2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getSubsampling2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getSubsampling2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getSubsampling2dNet(CNN2DFormat.NHWC, false, cm))
.net1(getSubsampling2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getSubsampling2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getSubsampling2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getSubsampling2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -140,7 +141,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testDepthwiseConv2d() {
@MethodSource("#params")
@ParameterizedTest
public void testDepthwiseConv2d(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -149,15 +152,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getDepthwiseConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getDepthwiseConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getDepthwiseConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getDepthwiseConv2dNet(CNN2DFormat.NHWC, false, cm))
.net1(getDepthwiseConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getDepthwiseConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getDepthwiseConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getDepthwiseConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -173,7 +176,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testSeparableConv2d() {
@MethodSource("#params")
@ParameterizedTest
public void testSeparableConv2d(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -182,15 +187,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSeparableConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getSeparableConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getSeparableConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getSeparableConv2dNet(CNN2DFormat.NHWC, false, cm))
.net1(getSeparableConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getSeparableConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getSeparableConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getSeparableConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -206,7 +211,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testDeconv2d() {
@MethodSource("#params")
@ParameterizedTest
public void testDeconv2d(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -215,15 +222,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getDeconv2DNet2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getDeconv2DNet2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getDeconv2DNet2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getDeconv2DNet2dNet(CNN2DFormat.NHWC, false, cm))
.net1(getDeconv2DNet2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getDeconv2DNet2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getDeconv2DNet2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getDeconv2DNet2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -239,7 +246,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testLRN() {
@MethodSource("#params")
@ParameterizedTest
public void testLRN(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -248,15 +257,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLrnLayer(CNN2DFormat.NCHW, true, cm))
.net2(getLrnLayer(CNN2DFormat.NCHW, false, cm))
.net3(getLrnLayer(CNN2DFormat.NHWC, true, cm))
.net4(getLrnLayer(CNN2DFormat.NHWC, false, cm))
.net1(getLrnLayer(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getLrnLayer(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getLrnLayer(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getLrnLayer(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -272,7 +281,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testZeroPaddingLayer(){
@MethodSource("#params")
@ParameterizedTest
public void testZeroPaddingLayer(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
@ -280,15 +291,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getZeroPaddingNet(CNN2DFormat.NCHW, true))
.net2(getZeroPaddingNet(CNN2DFormat.NCHW, false))
.net3(getZeroPaddingNet(CNN2DFormat.NHWC, true))
.net4(getZeroPaddingNet(CNN2DFormat.NHWC, false))
.net1(getZeroPaddingNet(dataType,CNN2DFormat.NCHW, true))
.net2(getZeroPaddingNet(dataType,CNN2DFormat.NCHW, false))
.net3(getZeroPaddingNet(dataType,CNN2DFormat.NHWC, true))
.net4(getZeroPaddingNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -303,7 +314,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testCropping2DLayer(){
@MethodSource("#params")
@ParameterizedTest
public void testCropping2DLayer(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
@ -311,15 +324,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getCropping2dNet(CNN2DFormat.NCHW, true))
.net2(getCropping2dNet(CNN2DFormat.NCHW, false))
.net3(getCropping2dNet(CNN2DFormat.NHWC, true))
.net4(getCropping2dNet(CNN2DFormat.NHWC, false))
.net1(getCropping2dNet(dataType,CNN2DFormat.NCHW, true))
.net2(getCropping2dNet(dataType,CNN2DFormat.NCHW, false))
.net3(getCropping2dNet(dataType,CNN2DFormat.NHWC, true))
.net4(getCropping2dNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -334,7 +347,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testUpsampling2d(){
@MethodSource("#params")
@ParameterizedTest
public void testUpsampling2d(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
@ -342,15 +357,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getUpsamplingNet(CNN2DFormat.NCHW, true))
.net2(getUpsamplingNet(CNN2DFormat.NCHW, false))
.net3(getUpsamplingNet(CNN2DFormat.NHWC, true))
.net4(getUpsamplingNet(CNN2DFormat.NHWC, false))
.net1(getUpsamplingNet(dataType,CNN2DFormat.NCHW, true))
.net2(getUpsamplingNet(dataType,CNN2DFormat.NCHW, false))
.net3(getUpsamplingNet(dataType,CNN2DFormat.NHWC, true))
.net4(getUpsamplingNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -365,7 +380,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testBatchNormNet(){
@MethodSource("#params")
@ParameterizedTest
public void testBatchNormNet(DataType dataType) {
try {
for(boolean useLogStd : new boolean[]{true, false}) {
for (boolean helpers : new boolean[]{false, true}) {
@ -374,15 +391,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std");
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, true))
.net2(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, false))
.net3(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, true))
.net4(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, false))
.net1(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NCHW, true))
.net2(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NCHW, false))
.net3(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NHWC, true))
.net4(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -398,7 +415,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testCnnLossLayer() {
@MethodSource("#params")
@ParameterizedTest
public void testCnnLossLayer(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
@ -406,8 +425,8 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labelsNHWC = TestUtils.randomOneHot(this.dataType,2*6*6, 3);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labelsNHWC = TestUtils.randomOneHot(dataType,2*6*6, 3);
labelsNHWC = labelsNHWC.reshape(2,6,6,3);
INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup();
@ -434,7 +453,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testSpaceToDepthNet(){
@MethodSource("#params")
@ParameterizedTest
public void testSpaceToDepthNet(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
@ -442,15 +463,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSpaceToDepthNet(CNN2DFormat.NCHW, true))
.net2(getSpaceToDepthNet(CNN2DFormat.NCHW, false))
.net3(getSpaceToDepthNet(CNN2DFormat.NHWC, true))
.net4(getSpaceToDepthNet(CNN2DFormat.NHWC, false))
.net1(getSpaceToDepthNet(dataType,CNN2DFormat.NCHW, true))
.net2(getSpaceToDepthNet(dataType,CNN2DFormat.NCHW, false))
.net3(getSpaceToDepthNet(dataType,CNN2DFormat.NHWC, true))
.net4(getSpaceToDepthNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -465,7 +486,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testSpaceToBatchNet(){
@MethodSource("#params")
@ParameterizedTest
public void testSpaceToBatchNet(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
@ -473,15 +496,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 16, 16);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 16, 16);
INDArray labels = TestUtils.randomOneHot(8, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSpaceToBatchNet(CNN2DFormat.NCHW, true))
.net2(getSpaceToBatchNet(CNN2DFormat.NCHW, false))
.net3(getSpaceToBatchNet(CNN2DFormat.NHWC, true))
.net4(getSpaceToBatchNet(CNN2DFormat.NHWC, false))
.net1(getSpaceToBatchNet(dataType,CNN2DFormat.NCHW, true))
.net2(getSpaceToBatchNet(dataType,CNN2DFormat.NCHW, false))
.net3(getSpaceToBatchNet(dataType,CNN2DFormat.NHWC, true))
.net4(getSpaceToBatchNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -496,7 +519,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testLocallyConnected() {
@MethodSource("#params")
@ParameterizedTest
public void testLocallyConnected(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -505,15 +530,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLocallyConnectedNet(CNN2DFormat.NCHW, true, cm))
.net2(getLocallyConnectedNet(CNN2DFormat.NCHW, false, cm))
.net3(getLocallyConnectedNet(CNN2DFormat.NHWC, true, cm))
.net4(getLocallyConnectedNet(CNN2DFormat.NHWC, false, cm))
.net1(getLocallyConnectedNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getLocallyConnectedNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getLocallyConnectedNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getLocallyConnectedNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -530,7 +555,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
@Test
public void testGlobalPooling() {
@MethodSource("#params")
@ParameterizedTest
public void testGlobalPooling(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (PoolingType pt : PoolingType.values()) {
@ -539,15 +566,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true))
.net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false))
.net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true))
.net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false))
.net1(getGlobalPoolingNet(dataType,CNN2DFormat.NCHW, pt, true))
.net2(getGlobalPoolingNet(dataType,CNN2DFormat.NCHW, pt, false))
.net3(getGlobalPoolingNet(dataType,CNN2DFormat.NHWC, pt, true))
.net4(getGlobalPoolingNet(dataType,CNN2DFormat.NHWC, pt, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -562,9 +589,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new ConvolutionLayer.Builder()
return getNetWithLayer(dataType,new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
@ -573,7 +600,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new ConvolutionLayer.Builder()
return getNetWithLayer(dataType,new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
@ -583,16 +610,16 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getSubsampling2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getSubsampling2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new SubsamplingLayer.Builder()
return getNetWithLayer(dataType,new SubsamplingLayer.Builder()
.kernelSize(2, 2)
.stride(1, 1)
.dataFormat(format)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new SubsamplingLayer.Builder()
return getNetWithLayer(dataType,new SubsamplingLayer.Builder()
.kernelSize(2, 2)
.stride(1, 1)
.helperAllowFallback(false)
@ -600,9 +627,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getSeparableConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getSeparableConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new SeparableConvolution2D.Builder()
return getNetWithLayer(dataType,new SeparableConvolution2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
@ -611,7 +638,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new SeparableConvolution2D.Builder()
return getNetWithLayer(dataType,new SeparableConvolution2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
@ -621,9 +648,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getDepthwiseConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getDepthwiseConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
return getNetWithLayer(dataType,new DepthwiseConvolution2D.Builder()
.depthMultiplier(2)
.kernelSize(3, 3)
.stride(2, 2)
@ -633,7 +660,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
return getNetWithLayer(dataType,new DepthwiseConvolution2D.Builder()
.depthMultiplier(2)
.kernelSize(3, 3)
.stride(2, 2)
@ -644,59 +671,59 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getLrnLayer(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getLrnLayer(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new LocalResponseNormalization.Builder()
return getNetWithLayer(dataType,new LocalResponseNormalization.Builder()
.dataFormat(format)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new LocalResponseNormalization.Builder()
return getNetWithLayer(dataType,new LocalResponseNormalization.Builder()
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getZeroPaddingNet(CNN2DFormat format, boolean setOnLayerAlso) {
private MultiLayerNetwork getZeroPaddingNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2)
return getNetWithLayer(dataType,new ZeroPaddingLayer.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2).build(),
return getNetWithLayer(dataType,new ZeroPaddingLayer.Builder(2,2).build(),
format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getCropping2dNet(CNN2DFormat format, boolean setOnLayerAlso) {
private MultiLayerNetwork getCropping2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new Cropping2D.Builder(2,2)
return getNetWithLayer(dataType,new Cropping2D.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new Cropping2D.Builder(2,2)
return getNetWithLayer(dataType,new Cropping2D.Builder(2,2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getUpsamplingNet(CNN2DFormat format, boolean setOnLayerAlso) {
private MultiLayerNetwork getUpsamplingNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new Upsampling2D.Builder(2)
return getNetWithLayer(dataType,new Upsampling2D.Builder(2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new Upsampling2D.Builder(2)
return getNetWithLayer(dataType,new Upsampling2D.Builder(2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getDeconv2DNet2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getDeconv2DNet2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
return getNetWithLayer(dataType,new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH)
.kernelSize(2,2)
.dataFormat(format)
.stride(2,2)
.build(), format, cm, null);
} else {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
return getNetWithLayer(dataType,new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH)
.kernelSize(2,2)
.dataFormat(format)
@ -705,50 +732,50 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getBatchNormNet(boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) {
private MultiLayerNetwork getBatchNormNet(DataType dataType,boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new BatchNormalization.Builder()
return getNetWithLayer(dataType,new BatchNormalization.Builder()
.useLogStd(logStdev)
.dataFormat(format)
.helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new BatchNormalization.Builder()
return getNetWithLayer(dataType,new BatchNormalization.Builder()
.useLogStd(logStdev)
.helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) {
private MultiLayerNetwork getSpaceToDepthNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToDepthLayer.Builder()
return getNetWithLayer(dataType,new SpaceToDepthLayer.Builder()
.blocks(2)
.dataFormat(format)
.build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new SpaceToDepthLayer.Builder()
return getNetWithLayer(dataType,new SpaceToDepthLayer.Builder()
.blocks(2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) {
private MultiLayerNetwork getSpaceToBatchNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToBatchLayer.Builder()
return getNetWithLayer(dataType,new SpaceToBatchLayer.Builder()
.blocks(2, 2)
.dataFormat(format)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
} else {
return getNetWithLayer(new SpaceToBatchLayer.Builder()
return getNetWithLayer(dataType,new SpaceToBatchLayer.Builder()
.blocks(2, 2)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
}
}
private MultiLayerNetwork getLocallyConnectedNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getLocallyConnectedNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new LocallyConnected2D.Builder()
return getNetWithLayer(dataType,new LocallyConnected2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
@ -756,7 +783,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.nOut(3)
.build(), format, cm, null);
} else {
return getNetWithLayer(new LocallyConnected2D.Builder()
return getNetWithLayer(dataType,new LocallyConnected2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
@ -765,9 +792,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) {
private MultiLayerNetwork getNetWithLayer(DataType dataType,Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) {
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.dataType(this.dataType)
.dataType(dataType)
.seed(12345)
.convolutionMode(cm)
.list()
@ -794,13 +821,13 @@ public class ConvDataFormatTests extends BaseDL4JTest {
return net;
}
private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
private MultiLayerNetwork getGlobalPoolingNet(DataType dataType,CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
return getNetWithLayer(dataType,new GlobalPoolingLayer.Builder(pt)
.poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2})
.build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
return getNetWithLayer(dataType,new GlobalPoolingLayer.Builder(pt)
.build(), format, ConvolutionMode.Same, null);
}
}

View File

@ -45,8 +45,11 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -61,30 +64,29 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
@RunWith(Parameterized.class)
@DisplayName("Bidirectional Test")
class BidirectionalTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public BidirectionalTest(RNNFormat rnnDataFormat) {
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params() {
return RNNFormat.values();
public static Stream<Arguments> params() {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Test
@DisplayName("Compare Implementations")
void compareImplementations() {
@ParameterizedTest
@MethodSource("#params")
void compareImplementations(RNNFormat rnnDataFormat) {
for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm);
// Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params
@ -147,9 +149,11 @@ class BidirectionalTest extends BaseDL4JTest {
}
}
@Test
@DisplayName("Compare Implementations Comp Graph")
void compareImplementationsCompGraph() {
@Test
@ParameterizedTest
@MethodSource("#params")
void compareImplementationsCompGraph(RNNFormat rnnFormat) {
// for(WorkspaceMode wsm : WorkspaceMode.values()) {
for (WorkspaceMode wsm : new WorkspaceMode[] { WorkspaceMode.NONE, WorkspaceMode.ENABLED }) {
log.info("*** Starting workspace mode: " + wsm);
@ -187,8 +191,8 @@ class BidirectionalTest extends BaseDL4JTest {
Gradient g2 = net2.gradient();
assertEquals(g1.gradient(), g2.gradient());
// Ensure updates are equal:
ComputationGraphUpdater u1 = (ComputationGraphUpdater) net1.getUpdater();
ComputationGraphUpdater u2 = (ComputationGraphUpdater) net2.getUpdater();
ComputationGraphUpdater u1 = net1.getUpdater();
ComputationGraphUpdater u2 = net2.getUpdater();
assertEquals(u1.getUpdaterStateViewArray(), u2.getUpdaterStateViewArray());
u1.update(g1, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces());
u2.update(g2, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces());
@ -205,7 +209,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test
@DisplayName("Test Serialization")
void testSerialization() throws Exception {
@ParameterizedTest
@MethodSource("#params")
void testSerialization(RNNFormat rnnDataFormat) throws Exception {
for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345);
@ -242,7 +248,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test
@DisplayName("Test Serialization Comp Graph")
void testSerializationCompGraph() throws Exception {
@ParameterizedTest
@MethodSource("#params")
void testSerializationCompGraph(RNNFormat rnnDataFormat) throws Exception {
for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345);
@ -277,7 +285,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test
@DisplayName("Test Simple Bidirectional")
void testSimpleBidirectional() {
@ParameterizedTest
@MethodSource("#params")
public void testSimpleBidirectional(RNNFormat rnnDataFormat) {
for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345);
@ -362,7 +372,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test
@DisplayName("Test Simple Bidirectional Comp Graph")
void testSimpleBidirectionalCompGraph() {
@ParameterizedTest
@MethodSource("#params")
void testSimpleBidirectionalCompGraph(RNNFormat rnnDataFormat) {
for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345);

View File

@ -19,7 +19,6 @@
*/
package org.deeplearning4j.nn.layers.recurrent;
import junit.framework.TestCase;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@ -34,9 +33,12 @@ import org.deeplearning4j.nn.params.GravesBidirectionalLSTMParamInitializer;
import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -44,31 +46,29 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.common.primitives.Pair;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(Parameterized.class)
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Graves Bidirectional LSTM Test")
class GravesBidirectionalLSTMTest extends BaseDL4JTest {
private double score = 0.0;
private RNNFormat rnnDataFormat;
public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat) {
this.rnnDataFormat = rnnDataFormat;
public static Stream<Arguments> params(){
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Parameterized.Parameters
public static Object[] params() {
return RNNFormat.values();
}
@Test
@DisplayName("Test Bidirectional LSTM Graves Forward Basic")
void testBidirectionalLSTMGravesForwardBasic() {
@MethodSource("#params")
@ParameterizedTest
void testBidirectionalLSTMGravesForwardBasic(RNNFormat rnnDataFormat) {
// Very basic test of forward prop. of LSTM layer with a time series.
// Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape.
int nIn = 13;
@ -110,19 +110,21 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test
@DisplayName("Test Bidirectional LSTM Graves Backward Basic")
void testBidirectionalLSTMGravesBackwardBasic() {
@MethodSource("#params")
@ParameterizedTest
void testBidirectionalLSTMGravesBackwardBasic(RNNFormat rnnDataFormat) {
// Very basic test of backprop for mini-batch + time series
// Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape.
testGravesBackwardBasicHelper(13, 3, 17, 10, 7);
testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 7);
// Edge case: miniBatchSize = 1
testGravesBackwardBasicHelper(13, 3, 17, 1, 7);
testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 1, 7);
// Edge case: timeSeriesLength = 1
testGravesBackwardBasicHelper(13, 3, 17, 10, 1);
testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 1);
// Edge case: both miniBatchSize = 1 and timeSeriesLength = 1
testGravesBackwardBasicHelper(13, 3, 17, 1, 1);
testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 1, 1);
}
private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) {
private void testGravesBackwardBasicHelper(RNNFormat rnnDataFormat,int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) {
INDArray inputData = (rnnDataFormat == RNNFormat.NCW) ? Nd4j.ones(miniBatchSize, nIn, timeSeriesLength) : Nd4j.ones(miniBatchSize, timeSeriesLength, nIn);
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat).dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()).build();
long numParams = conf.getLayer().initializer().numParams(conf);
@ -204,7 +206,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test
@DisplayName("Test Get Set Parmas")
void testGetSetParmas() {
@MethodSource("#params")
@ParameterizedTest
void testGetSetParmas(RNNFormat rnnDataFormat) {
final int nIn = 2;
final int layerSize = 3;
final int miniBatchSize = 2;
@ -224,7 +228,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test
@DisplayName("Test Simple Forwards And Backwards Activation")
void testSimpleForwardsAndBackwardsActivation() {
@MethodSource("#params")
@ParameterizedTest
void testSimpleForwardsAndBackwardsActivation(RNNFormat rnnDataFormat) {
final int nIn = 2;
final int layerSize = 3;
final int miniBatchSize = 1;
@ -342,7 +348,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test
@DisplayName("Test Gate Activation Fns Sanity Check")
void testGateActivationFnsSanityCheck() {
@MethodSource("#params")
@ParameterizedTest
void testGateActivationFnsSanityCheck(RNNFormat rnnDataFormat) {
for (String gateAfn : new String[] { "sigmoid", "hardsigmoid" }) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat).activation(Activation.TANH).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);

View File

@ -30,36 +30,35 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.Arrays;
import java.util.Collections;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(Parameterized.class)
@DisplayName("Mask Zero Layer Test")
class MaskZeroLayerTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public MaskZeroLayerTest(RNNFormat rnnDataFormat) {
this.rnnDataFormat = rnnDataFormat;
public static Stream<Arguments> params() {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Parameterized.Parameters
public static Object[] params() {
return RNNFormat.values();
}
@Test
@DisplayName("Activate")
void activate() {
@Test
@ParameterizedTest
@MethodSource("#params")
void activate(RNNFormat rnnDataFormat) {
// GIVEN two examples where some of the timesteps are zero.
INDArray ex1 = Nd4j.create(new double[][] { new double[] { 0, 3, 5 }, new double[] { 0, 0, 2 } });
INDArray ex2 = Nd4j.create(new double[][] { new double[] { 0, 0, 2 }, new double[] { 0, 0, 2 } });
@ -95,9 +94,12 @@ class MaskZeroLayerTest extends BaseDL4JTest {
assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6);
}
@Test
@DisplayName("Test Serialization")
void testSerialization() {
@Test
@ParameterizedTest
@MethodSource("#params")
void testSerialization(RNNFormat rnnDataFormat) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder().setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

View File

@ -40,8 +40,10 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -51,30 +53,31 @@ import org.nd4j.common.primitives.Pair;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
@AllArgsConstructor
public class RnnDataFormatTests extends BaseDL4JTest {
private boolean helpers;
private boolean lastTimeStep;
private boolean maskZeros;
@Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}")
public static List params(){
public static Stream<Arguments> params() {
List<Object[]> ret = new ArrayList<>();
for (boolean helpers: new boolean[]{true, false})
for (boolean lastTimeStep: new boolean[]{true, false})
for (boolean maskZero: new boolean[]{true, false})
ret.add(new Object[]{helpers, lastTimeStep, maskZero});
return ret;
return ret.stream().map(Arguments::of);
}
@Test
public void testSimpleRnn() {
@MethodSource("#params")
@ParameterizedTest
public void testSimpleRnn(boolean helpers,
boolean lastTimeStep,
boolean maskZeros
) {
try {
Nd4j.getRandom().setSeed(12345);
@ -107,7 +110,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
}
@Test
public void testLSTM() {
@ParameterizedTest
@MethodSource("#params")
public void testLSTM(boolean helpers,
boolean lastTimeStep,
boolean maskZeros) {
try {
Nd4j.getRandom().setSeed(12345);
@ -141,7 +148,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
@Test
public void testGraveLSTM() {
@MethodSource("#params")
@ParameterizedTest
public void testGraveLSTM(boolean helpers,
boolean lastTimeStep,
boolean maskZeros) {
try {
Nd4j.getRandom().setSeed(12345);
@ -175,7 +186,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
@Test
public void testGraveBiLSTM() {
@MethodSource("#params")
@ParameterizedTest
public void testGraveBiLSTM(boolean helpers,
boolean lastTimeStep,
boolean maskZeros) {
try {
Nd4j.getRandom().setSeed(12345);

View File

@ -34,14 +34,20 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.learning.config.AdaGrad;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.deeplearning4j.nn.api.OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
import static org.deeplearning4j.nn.weights.WeightInit.XAVIER_UNIFORM;
import static org.junit.jupiter.api.Assertions.*;
@ -50,20 +56,16 @@ import static org.nd4j.linalg.activations.Activation.TANH;
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;
@RunWith(Parameterized.class)
public class TestLastTimeStepLayer extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestLastTimeStepLayer(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters(name="{0}")
public static Object[] params(){
return RNNFormat.values();
public static Stream<Arguments> params(){
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Test
public void testLastTimeStepVertex() {
@ParameterizedTest
@MethodSource("#params")
public void testLastTimeStepVertex(RNNFormat rnnDataFormat) {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
.addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder()
@ -126,7 +128,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
}
@Test
public void testMaskingAndAllMasked(){
@ParameterizedTest
@MethodSource("#params")
public void testMaskingAndAllMasked(RNNFormat rnnDataFormat) {
ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
.optimizationAlgo(STOCHASTIC_GRADIENT_DESCENT)
.weightInit(XAVIER_UNIFORM)

View File

@ -36,8 +36,11 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -49,25 +52,23 @@ import org.nd4j.common.primitives.Pair;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class TestRnnLayers extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestRnnLayers(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
public static Stream<Arguments> params(){
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Test
public void testTimeStepIs3Dimensional() {
@ParameterizedTest
@MethodSource("#params")
public void testTimeStepIs3Dimensional(RNNFormat rnnDataFormat) {
int nIn = 12;
int nOut = 3;
@ -117,7 +118,9 @@ public class TestRnnLayers extends BaseDL4JTest {
}
@Test
public void testDropoutRecurrentLayers(){
@ParameterizedTest
@MethodSource("#params")
public void testDropoutRecurrentLayers(RNNFormat rnnDataFormat){
Nd4j.getRandom().setSeed(12345);
String[] layerTypes = new String[]{"graves", "lstm", "simple"};
@ -215,7 +218,9 @@ public class TestRnnLayers extends BaseDL4JTest {
}
@Test
public void testMismatchedInputLabelLength(){
@ParameterizedTest
@MethodSource("#params")
public void testMismatchedInputLabelLength(RNNFormat rnnDataFormat){
for( int i = 0; i < 2; i++) {

View File

@ -29,8 +29,10 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -38,25 +40,25 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;
@RunWith(Parameterized.class)
public class TestSimpleRnn extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestSimpleRnn(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
public static Stream<Arguments> params() {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Test
public void testSimpleRnn(){
@ParameterizedTest
@MethodSource("#params")
public void testSimpleRnn(RNNFormat rnnDataFormat) {
Nd4j.getRandom().setSeed(12345);
int m = 3;
@ -125,7 +127,9 @@ public class TestSimpleRnn extends BaseDL4JTest {
}
@Test
public void testBiasInit(){
@ParameterizedTest
@MethodSource("#params")
public void testBiasInit(RNNFormat rnnDataFormat) {
Nd4j.getRandom().setSeed(12345);
int nIn = 5;
int layerSize = 6;

View File

@ -37,8 +37,10 @@ import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -47,22 +49,22 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class TestTimeDistributed extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestTimeDistributed(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
public static Stream<Arguments> params(){
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Test
public void testTimeDistributed(){
@ParameterizedTest
@MethodSource("#params")
public void testTimeDistributed(RNNFormat rnnDataFormat){
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
@ -133,7 +135,9 @@ public class TestTimeDistributed extends BaseDL4JTest {
@Test
public void testTimeDistributedDense(){
@MethodSource("#params")
@ParameterizedTest
public void testTimeDistributedDense(RNNFormat rnnDataFormat){
for( int rnnType = 0; rnnType < 3; rnnType++ ) {
for( int ffType = 0; ffType < 3; ffType++ ) {

View File

@ -39,8 +39,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -145,92 +145,6 @@
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-plugin</artifactId>
<version>1.4.30-M1</version>
<configuration>
<args>
<arg>-Xjsr305=strict</arg>
</args>
<compilerPlugins>
<plugin>spring</plugin>
<plugin>jpa</plugin>
</compilerPlugins>
</configuration>
<dependencies>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-allopen</artifactId>
<version>${kotlin.version}</version>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-noarg</artifactId>
<version>${kotlin.version}</version>
</dependency>
</dependencies>
<executions>
<execution>
<id>compile</id>
<goals> <goal>compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/main/stubs</sourceDir>
<sourceDir>${project.basedir}/src/main/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/main/java</sourceDir>
<sourceDir>${project.basedir}/src/main/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
<execution>
<id>test-compile</id>
<goals> <goal>test-compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/test/stubs</sourceDir>
<sourceDir>${project.basedir}/src/test/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/test/java</sourceDir>
<sourceDir>${project.basedir}/src/test/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
</executions>
</plugin>
<!-- https://kotlinlang.org/docs/reference/using-maven.html -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.5.1</version>
<executions>
<!-- Replacing default-compile as it is treated specially by maven -->
<execution>
<id>default-compile</id>
<phase>none</phase>
</execution>
<!-- Replacing default-testCompile as it is treated specially by maven -->
<execution>
<id>default-testCompile</id>
<phase>none</phase>
</execution>
<execution>
<id>java-compile</id>
<phase>compile</phase>
<goals> <goal>compile</goal> </goals>
</execution>
<execution>
<id>java-test-compile</id>
<phase>test-compile</phase>
<goals> <goal>testCompile</goal> </goals>
</execution>
</executions>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
</plugins>
</build>
@ -244,7 +158,10 @@
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-stdlib-jdk8</artifactId>
@ -261,11 +178,14 @@
<groupId>org.nd4j</groupId>
<artifactId>samediff-import-tensorflow</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>samediff-import-onnx</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>

View File

@ -22,11 +22,6 @@ package org.nd4j;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.imports.tfgraphs.TFGraphTestAllLibnd4j;
import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
import org.nd4j.imports.tfgraphs.TFGraphTestList;
import org.nd4j.imports.tfgraphs.TFGraphTestZooModels;
import org.nd4j.imports.listeners.ImportModelDebugger;
import java.util.*;
@Slf4j
@ -36,11 +31,6 @@ public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>(Arrays.asList(
TFGraphTestAllSameDiff.class,
TFGraphTestAllLibnd4j.class,
TFGraphTestList.class,
TFGraphTestZooModels.class,
ImportModelDebugger.class //Run manually only, otherwise ignored
));
}

View File

@ -20,19 +20,16 @@
package org.nd4j;
import org.bytedeco.javacpp.Loader;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.nd4j.autodiff.opvalidation.*;
import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
//import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.function.Function;
import static org.junit.Assume.assumeFalse;
@ -49,7 +46,7 @@ import static org.junit.Assume.assumeFalse;
TransformOpValidation.class,
//TF import tests
TFGraphTestAllSameDiff.class
//TFGraphTestAllSameDiff.class
//TFGraphTestAllLibnd4j.class
})
//IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test"

View File

@ -27,10 +27,12 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.ImportClassMapping;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.NoOp;
import org.nd4j.linalg.api.ops.compat.CompatSparseToDense;
@ -122,13 +124,11 @@ import java.util.regex.Matcher;
import java.util.regex.Pattern;
@Disabled("No longer relevant after model import rewrite.")
public class TestOpMapping extends BaseNd4jTest {
public class TestOpMapping extends BaseNd4jTestWithBackends {
Set<Class<? extends DifferentialFunction>> subTypes;
public TestOpMapping(Nd4jBackend b){
super(b);
public TestOpMapping() {
Reflections reflections = new Reflections("org.nd4j");
subTypes = reflections.getSubTypesOf(DifferentialFunction.class);
}
@ -146,6 +146,8 @@ public class TestOpMapping extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOpMappingCoverage() throws Exception {
Map<String, DifferentialFunction> opNameMapping = ImportClassMapping.getOpNameMapping();
Map<String, DifferentialFunction> tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions();
@ -196,7 +198,9 @@ public class TestOpMapping extends BaseNd4jTest {
}
@Test
public void testOpsInNamespace() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOpsInNamespace(Nd4jBackend backend) throws Exception {
//Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't
// want to add to a namespace for some reason)
//Note that we ignore "*Bp", "*Gradient", "*Derivative" etc ops
@ -354,8 +358,11 @@ public class TestOpMapping extends BaseNd4jTest {
s.add(Assign.class);
}
@Test @Disabled
public void generateOpClassList() throws Exception{
@Test
@Disabled
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void generateOpClassList(Nd4jBackend backend) throws Exception{
Reflections reflections = new Reflections("org.nd4j");
Set<Class<? extends DifferentialFunction>> subTypes = reflections.getSubTypesOf(DifferentialFunction.class);
@ -366,12 +373,7 @@ public class TestOpMapping extends BaseNd4jTest {
l.add(c);
}
Collections.sort(l, new Comparator<Class<?>>() {
@Override
public int compare(Class<?> o1, Class<?> o2) {
return o1.getName().compareTo(o2.getName());
}
});
Collections.sort(l, Comparator.comparing(Class::getName));
for(Class<?> c : l){
System.out.println(c.getName() + ".class,");

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SDVariable;
@ -31,7 +33,7 @@ import org.nd4j.autodiff.samediff.internal.FrameIter;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.memory.NoOpMemoryMgr;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -46,11 +48,7 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
public class TestSessions extends BaseNd4jTest {
public TestSessions(Nd4jBackend b){
super(b);
}
public class TestSessions extends BaseNd4jTestWithBackends {
@Override
public char ordering() {
@ -58,7 +56,9 @@ public class TestSessions extends BaseNd4jTest {
}
@Test
public void testInferenceSessionBasic(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInferenceSessionBasic(Nd4jBackend backend) {
//So far: trivial test to check execution order
SameDiff sd = SameDiff.create();
@ -90,7 +90,9 @@ public class TestSessions extends BaseNd4jTest {
@Test
public void testInferenceSessionBasic2(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInferenceSessionBasic2(Nd4jBackend backend) {
//So far: trivial test to check execution order
SameDiff sd = SameDiff.create();
@ -126,7 +128,9 @@ public class TestSessions extends BaseNd4jTest {
}
@Test
public void testMergeSimple(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeSimple(Nd4jBackend backend) {
//This isn't really a sensible graph, as merge op behaviour is undefined when multiple inputs are available...
SameDiff sd = SameDiff.create();
@ -162,7 +166,9 @@ public class TestSessions extends BaseNd4jTest {
@Test
public void testSwitchSimple(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSwitchSimple(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 3,3);

View File

@ -21,10 +21,12 @@
package org.nd4j.autodiff.internal;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.internal.DependencyList;
import org.nd4j.autodiff.samediff.internal.DependencyTracker;
import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -35,11 +37,8 @@ import java.util.Collections;
import static junit.framework.TestCase.assertNotNull;
import static org.junit.jupiter.api.Assertions.*;
public class TestDependencyTracker extends BaseNd4jTest {
public class TestDependencyTracker extends BaseNd4jTestWithBackends {
public TestDependencyTracker(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -47,7 +46,9 @@ public class TestDependencyTracker extends BaseNd4jTest {
}
@Test
public void testSimple(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>();
@ -94,7 +95,9 @@ public class TestDependencyTracker extends BaseNd4jTest {
}
@Test
public void testSatisfiedBeforeAdd(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSatisfiedBeforeAdd(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>();
//Check different order of adding dependencies: i.e., mark X as satisfied, then add x -> y dependency
@ -133,7 +136,9 @@ public class TestDependencyTracker extends BaseNd4jTest {
}
@Test
public void testMarkUnsatisfied(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMarkUnsatisfied(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>();
dt.addDependency("y", "x");
@ -165,6 +170,8 @@ public class TestDependencyTracker extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIdentityDependencyTracker(){
IdentityDependencyTracker<INDArray, String> dt = new IdentityDependencyTracker<>();
assertTrue(dt.isEmpty());

View File

@ -21,6 +21,8 @@
package org.nd4j.autodiff.opvalidation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.GradCheckUtil;
@ -38,12 +40,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class ActivationGradChecks extends BaseOpValidation {
public ActivationGradChecks(Nd4jBackend backend) {
super(backend);
}
@Test
public void testActivationGradientCheck1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testActivationGradientCheck1(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4));
@ -61,7 +62,9 @@ public class ActivationGradChecks extends BaseOpValidation {
}
@Test
public void testActivationGradientCheck2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testActivationGradientCheck2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.DOUBLE, 3, 4);

View File

@ -21,18 +21,14 @@
package org.nd4j.autodiff.opvalidation;
import org.junit.jupiter.api.BeforeEach;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
public abstract class BaseOpValidation extends BaseNd4jTest {
public abstract class BaseOpValidation extends BaseNd4jTestWithBackends {
private DataType initialType;
private DataType initialType = Nd4j.dataType();
public BaseOpValidation(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {

View File

@ -27,6 +27,8 @@ import java.util.List;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.OpValidation;
@ -65,9 +67,6 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class LayerOpValidation extends BaseOpValidation {
public LayerOpValidation(Nd4jBackend backend) {
super(backend);
}
@Override
public long getTimeoutMilliseconds() {
@ -75,7 +74,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testXwPlusB() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testXwPlusB(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sameDiff = SameDiff.create();
@ -109,7 +110,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testReluLayer() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReluLayer(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sameDiff = SameDiff.create();
@ -137,7 +140,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testBiasAdd() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAdd(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sameDiff = SameDiff.create();
@ -161,7 +166,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testConv2d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv2d(Nd4jBackend backend) {
//avg pool, batch norm, conv2d, max pool 2d, pooling2d, upsampling
//Tested elsewhere: deconv2d, depthwise2d, LRN, sconv2d
@ -301,7 +308,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testLrn2d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLrn2d(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}};
@ -342,7 +351,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testIm2Col() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIm2Col(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873
Nd4j.getRandom().setSeed(12345);
@ -381,7 +392,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testOutputShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOutputShape(Nd4jBackend backend) {
long[] inSize = {1, 8, 8, 3};
SameDiff sd = SameDiff.create();
@ -431,7 +444,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testAvgPool() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAvgPool(Nd4jBackend backend) {
long[] inSize = {1, 8, 8, 3}; //NHWC
Pooling2DConfig conf = Pooling2DConfig.builder()
@ -474,7 +489,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testConv3d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv3d(Nd4jBackend backend) {
//Pooling3d, Conv3D, batch norm
Nd4j.getRandom().setSeed(12345);
@ -576,7 +593,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testDepthWiseConv2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepthWiseConv2dBasic(Nd4jBackend backend) {
int nIn = 3;
int depthWise = 4;
int kH = 2;
@ -615,7 +634,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testSeparableConv2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSeparableConv2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int nIn = 2;
int nOut = 3;
@ -671,7 +692,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testDeconv2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeconv2dBasic(Nd4jBackend backend) {
int nIn = 2;
int nOut = 3;
int kH = 2;
@ -715,7 +738,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testConv2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv2dBasic(Nd4jBackend backend) {
int nIn = 3;
int nOut = 4;
int kH = 2;
@ -756,7 +781,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testMaxPoolingArgMax() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxPoolingArgMax(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int nIn = 3;
int kH = 2;
@ -785,7 +812,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testMaxPooling2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxPooling2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int nIn = 3;
int kH = 2;
@ -843,7 +872,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testAvgPooling2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAvgPooling2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int nIn = 3;
int kH = 2;
@ -892,7 +923,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testAvgPooling3dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAvgPooling3dBasic(Nd4jBackend backend) {
int nIn = 3;
int kH = 2;
int kW = 2;
@ -929,7 +962,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testMaxPooling3dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxPooling3dBasic(Nd4jBackend backend) {
int nIn = 3;
int kH = 2;
int kW = 2;
@ -967,7 +1002,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testConv1dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1dBasic(Nd4jBackend backend) {
int nIn = 3;
int nOut = 4;
int k = 2;
@ -1002,7 +1039,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testConv1dCausal() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1dCausal(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int nIn = 3;
int nOut = 4;
@ -1051,7 +1090,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testConv1dForward() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1dForward(Nd4jBackend backend) {
int nIn = 2;
int nOut = 1;
int kernel = 3;
@ -1094,7 +1135,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testConv3dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv3dBasic(Nd4jBackend backend) {
int nIn = 3;
int nOut = 4;
int kH = 2;
@ -1140,7 +1183,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testDeConv3dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeConv3dBasic(Nd4jBackend backend) {
int nIn = 4;
int nOut = 3;
int kH = 2;
@ -1185,7 +1230,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testLayerNorm() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNorm(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1210,7 +1257,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testLayerNorm4d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNorm4d(Nd4jBackend backend) {
int mb = 3;
int ch = 4;
for (boolean nchw : new boolean[]{true, false}) {
@ -1242,7 +1291,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testLayerNormOP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormOP(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1258,7 +1309,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testLayerNormNoBias() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormNoBias(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1281,7 +1334,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testLayerNormOPNoBias() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormOPNoBias(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1296,7 +1351,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testLayerNormNoDeviation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormNoDeviation(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
for (int i = 0; i < 4; i++) {
random.putScalar(1, i, 7);
@ -1326,7 +1383,7 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test()
public void exceptionThrown_WhenConv1DConfigInvalid() {
public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
int nIn = 3;
int nOut = 4;
@ -1355,7 +1412,7 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test()
public void exceptionThrown_WhenConv2DConfigInvalid() {
public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345);
@ -1378,7 +1435,7 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test()
public void exceptionThrown_WhenConf3DInvalid() {
public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345);
@ -1411,7 +1468,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testLayerNormMixedOrders() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormMixedOrders(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');
INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f');
@ -1458,7 +1517,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testBiasAdd_nchw_nhwc() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAdd_nchw_nhwc(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
for (boolean nchw : new boolean[]{true, false}) {
@ -1489,6 +1550,8 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepthwiseConv2D(){
int bS = 10;
@ -1527,7 +1590,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void LSTMLayerTestCase1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void LSTMLayerTestCase1(Nd4jBackend backend) {
int bS = 5;
int nIn = 3;
@ -1602,7 +1667,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void LSTMLayerTestCase2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void LSTMLayerTestCase2(Nd4jBackend backend) {
int bS = 5;
int nIn = 3;
int numUnits = 7;
@ -1660,7 +1727,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void LSTMLayerTestCase3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void LSTMLayerTestCase3(Nd4jBackend backend) {
int bS = 5;
int nIn = 3;
int numUnits = 7;
@ -1721,7 +1790,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void GRUTestCase() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void GRUTestCase(Nd4jBackend backend) {
int bS = 5;
int nIn = 4;
int nOut = 6;

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
@ -43,9 +45,7 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class LossOpValidation extends BaseOpValidation {
public LossOpValidation(Nd4jBackend backend) {
super(backend);
}
@Override
public long getTimeoutMilliseconds() {
@ -56,7 +56,9 @@ public class LossOpValidation extends BaseOpValidation {
public static final Set<String> NO_BP_YET = new HashSet<>();
@Test
public void testLoss2d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoss2d(Nd4jBackend backend) {
final List<String> oneDimensionalOutputFns = Arrays.asList("cosine", "mpwse", "softmaxxent", "softmaxxent_smooth", "mpwse", "sparsesoftmax");
Nd4j.getRandom().setSeed(12345);
@ -368,6 +370,8 @@ public class LossOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCosineDistance(){
INDArray arr = Nd4j.create(new double[][]{{-0.3, -0.2, -0.1}, {0, 0.1, 0.2}});
INDArray label = Nd4j.create(new double[][]{{1.0, 2.0, 3.0}, {-1.0, 2.0, 1.0}});
@ -386,6 +390,8 @@ public class LossOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testL2Loss(){
for( int rank=0; rank<=3; rank++ ){
@ -428,7 +434,9 @@ public class LossOpValidation extends BaseOpValidation {
}
@Test
public void testNonZeroResult() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNonZeroResult(Nd4jBackend backend) {
INDArray predictions = Nd4j.rand(DataType.DOUBLE, 10, 5);
INDArray w = Nd4j.scalar(1.0);
INDArray label = Nd4j.rand(DataType.DOUBLE, 10, 5);
@ -486,6 +494,8 @@ public class LossOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void TestStdLossMixedDataType(){
// Default Data Type in this test suite is Double.
// This test used to throw an Exception that we have mixed data types.

View File

@ -23,6 +23,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -78,13 +80,12 @@ import static org.junit.Assume.assumeNotNull;
@Slf4j
public class MiscOpValidation extends BaseOpValidation {
public MiscOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test
public void testGradientAutoBroadcast1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGradientAutoBroadcast1(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
@ -171,7 +172,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testGradientAutoBroadcast2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGradientAutoBroadcast2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
@ -260,7 +263,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testGradientAutoBroadcast3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGradientAutoBroadcast3(Nd4jBackend backend) {
//These tests: output size > input sizes
Nd4j.getRandom().setSeed(12345);
@ -368,7 +373,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testScatterOpGradients() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScatterOpGradients(Nd4jBackend backend) {
List<String> failed = new ArrayList<>();
for (int i = 0; i < 7; i++) {
@ -470,6 +477,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScatterUpdate(){
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3);
INDArray updates = Nd4j.create(new float[][]{
@ -491,7 +500,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testGatherGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
@ -542,6 +553,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrace(){
//TODO need to work out how to handle shape_op for scalars...
//OpValidationSuite.ignoreFailing();
@ -567,7 +580,9 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
public void testTensorGradTensorMmul() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTensorGradTensorMmul(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing();
Nd4j.getRandom().setSeed(12345);
@ -589,7 +604,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testMulGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMulGradient(Nd4jBackend backend) {
INDArray arr1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray arr2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
@ -654,22 +671,21 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
public void testMmulGradientManual() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulGradientManual(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
Map<String, INDArray> inputs = new HashMap<>();
inputs.put("x", sumInput);
inputs.put("y", sumInput.dup());
sameDiff.defineFunction("mmulGradient", new SameDiffFunctionDefinition() {
@Override
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
SDVariable input = sameDiff.var("x", inputs.get("x"));
SDVariable input2 = sameDiff.var("y", inputs.get("y"));
SDVariable exp = sameDiff.mmul(input, input2);
SDVariable sum = sameDiff.sum(exp, Integer.MAX_VALUE);
sameDiff.defineFunction("mmulGradient", (sameDiff1, inputs1, variableInputs) -> {
SDVariable input = sameDiff1.var("x", inputs1.get("x"));
SDVariable input2 = sameDiff1.var("y", inputs1.get("y"));
SDVariable exp = sameDiff1.mmul(input, input2);
SDVariable sum = sameDiff1.sum(exp, Integer.MAX_VALUE);
return new SDVariable[]{sum};
}
}, inputs);
@ -698,6 +714,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulGradients(){
int[] aShape = new int[]{2,3};
int[] bShape = new int[]{3,4};
@ -749,7 +767,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testBatchMmulBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBatchMmulBasic(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6873
int M = 5;
int N = 3;
@ -774,7 +794,9 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
public void testMmulWithTranspose() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulWithTranspose(Nd4jBackend backend) {
//Here: [x,3]^T * [x,4] = [3,4]
@ -811,6 +833,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulOutputSizeCalculation(){
//[3,2] x [2,4] with result transpose: output shape [4,3]
INDArray a = Nd4j.create(3,2);
@ -843,6 +867,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFillOp(){
INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT);
@ -857,6 +883,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm(){
//Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -889,6 +917,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm2(){
//Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -932,6 +962,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm1(){
//Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -972,6 +1004,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm0(){
//Expected: if array.norm2(0) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -1001,6 +1035,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCumSum(){
List<String> failing = new ArrayList<>();
@ -1066,6 +1102,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCumProd(){
List<String> failing = new ArrayList<>();
@ -1134,6 +1172,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot1(){
List<String> failed = new ArrayList<>();
@ -1164,6 +1204,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHotOp(){
//https://www.tensorflow.org/api_docs/python/tf/one_hot
//https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp
@ -1178,7 +1220,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testOneHot2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot2(Nd4jBackend backend) {
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
@ -1198,7 +1242,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testOneHot4() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot4(Nd4jBackend backend) {
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
@ -1218,7 +1264,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testOneHot3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot3(Nd4jBackend backend) {
//https://github.com/deeplearning4j/deeplearning4j/issues/6872
//https://www.tensorflow.org/api_docs/python/tf/one_hot
@ -1253,6 +1301,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLinspace(){
SameDiff sd = SameDiff.create();
SDVariable out = sd.linspace("linspace", DataType.DOUBLE, 1,10,10);
@ -1266,6 +1316,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLinspace2(){
OpValidationSuite.ignoreFailing(); //TODO 2019/01/18
SameDiff sd = SameDiff.create();
@ -1280,7 +1332,9 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
public void testShapeFn() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShapeFn(Nd4jBackend backend) {
INDArray in = Nd4j.create(new long[]{1, 2});
@ -1294,7 +1348,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testShapeFn2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShapeFn2(Nd4jBackend backend) {
INDArray i = Nd4j.create(1,3);
@ -1307,6 +1363,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeRank1(){
SameDiff sd = SameDiff.create();
SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5));
@ -1325,7 +1383,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testDiagPart() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagPart(Nd4jBackend backend) {
INDArray i = Nd4j.create(5,5);
SameDiff sd = SameDiff.create();
@ -1337,7 +1397,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testDiagShapeFn() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagShapeFn(Nd4jBackend backend) {
INDArray i = Nd4j.create(5,5);
CustomOp op = new DiagPart(i, null);
@ -1350,6 +1412,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZerosOnesLike(){
Nd4j.getRandom().setSeed(12345);
@ -1392,6 +1456,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZerosLikeOp(){
INDArray arr = Nd4j.scalar(DataType.DOUBLE, 1.0);
@ -1407,6 +1473,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConfusionMatrix(){
DataType dt = DataType.DOUBLE;
@ -1443,6 +1511,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIsNonDecreasingIsStrictlyIncr(){
List<long[]> shapes = Arrays.asList(null, new long[]{12}, new long[]{1,12}, new long[]{3,4}, new long[]{2,2,3});
@ -1506,6 +1576,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExtractImagePatches(){
/*
tf.reset_default_graph()
@ -1553,6 +1625,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentProdBpSimple(){
INDArray segmentIdxs = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT);
@ -1573,6 +1647,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulRank4() throws Exception {
Nd4j.getRandom().setSeed(12345);
@ -1608,6 +1684,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulRank4_simple(){
INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
@ -1634,6 +1712,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNthElementRank1(){
INDArray in = Nd4j.createFromArray(new double[]{0,1,2,3,4,5,6,7,8,9});
INDArray n = Nd4j.scalar(0);
@ -1656,6 +1736,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTensorMmulShape(){
INDArray a = Nd4j.create(new double[]{2}).reshape(1);
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
@ -1674,6 +1756,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTensorMmulShape2(){
INDArray a = Nd4j.create(new double[]{2}).reshape(1);
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
@ -1682,6 +1766,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStopGradient(){
SameDiff sd = SameDiff.create();
@ -1701,6 +1787,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckNumerics(){
OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/7927
@ -1744,7 +1832,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testCheckNumerics2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckNumerics2(Nd4jBackend backend) {
INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4);
INDArray msg = Nd4j.scalar("My error message!");
@ -1757,6 +1847,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testHistogramFixedWidth(){
//Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf]
INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9);
@ -1775,6 +1867,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicPartition(){
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
@ -1793,6 +1887,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testListDiff(){
INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
INDArray y = Nd4j.createFromArray(3, 1);
@ -1812,7 +1908,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testDivideNoNan() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDivideNoNan(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff()
SameDiff sameDiff = SameDiff.create();
@ -1836,7 +1934,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testDigamma() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDigamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -1851,7 +1951,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testFlatten() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFlatten(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1873,7 +1975,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testFusedBatchNorm() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFusedBatchNorm(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create();
@ -1918,7 +2022,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testIgamma() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIgamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -1934,7 +2040,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testIgammaC() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIgammaC(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -1951,7 +2059,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testLgamma() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLgamma(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1976,7 +2086,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testLu() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLu(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -2007,7 +2119,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testMatrixBandPart() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixBandPart(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create();
@ -2037,7 +2151,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testPolygamma() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPolygamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -2053,7 +2169,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testTriangularSolve() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTriangularSolve(Nd4jBackend backend) {
INDArray a = Nd4j.createFromArray(new float[]{
3.f, 0.f, 0.f, 0.f,
@ -2077,7 +2195,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testBiasAdd() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAdd(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -2106,7 +2226,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testBiasAddGrad() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAddGrad(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -2126,7 +2248,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testRoll() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRoll(Nd4jBackend backend) {
INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}).
@ -2146,6 +2270,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSeqMask(){
INDArray arr = Nd4j.createFromArray(1,2,3);
INDArray maxLen = Nd4j.scalar(4);

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -51,12 +53,11 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class RandomOpValidation extends BaseOpValidation {
public RandomOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test
public void testRandomOpsSDVarShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomOpsSDVarShape(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
@ -157,7 +158,9 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
public void testRandomOpsLongShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomOpsLongShape(Nd4jBackend backend) {
List<String> failed = new ArrayList<>();
for (long[] shape : Arrays.asList(new long[]{1000}, new long[]{100, 10}, new long[]{40, 5, 5})) {
@ -283,6 +286,8 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomBinomial(){
INDArray z = Nd4j.create(new long[]{10});
@ -293,7 +298,9 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
public void testUniformRankSimple() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUniformRankSimple(Nd4jBackend backend) {
INDArray arr = Nd4j.createFromArray(new double[]{100.0});
// OpTestCase tc = new OpTestCase(DynamicCustomOp.builder("randomuniform")
@ -325,7 +332,9 @@ public class RandomOpValidation extends BaseOpValidation {
@Test
public void testRandomExponential() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomExponential(Nd4jBackend backend) {
long length = 1_000_000;
INDArray shape = Nd4j.createFromArray(new double[]{length});
INDArray out = Nd4j.createUninitialized(new long[]{length});
@ -347,6 +356,8 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRange(){
//Technically deterministic, not random...
@ -380,6 +391,8 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllEmptyReduce(){
INDArray x = Nd4j.createFromArray(true, true, true);
All all = new All(x);
@ -389,6 +402,8 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUniformDtype(){
Nd4j.getRandom().setSeed(12345);
for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
@ -417,6 +432,8 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomExponential2(){
Nd4j.getRandom().setSeed(12345);
DynamicCustomOp op = DynamicCustomOp.builder("random_exponential")

View File

@ -24,6 +24,8 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.validation.OpTestCase;
import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.linalg.api.buffer.DataType;
@ -51,10 +53,6 @@ public class ReductionBpOpValidation extends BaseOpValidation {
private DataType initialType;
public ReductionBpOpValidation(Nd4jBackend backend) {
super(backend);
}
@BeforeEach
public void before() {
Nd4j.create(1);
@ -71,14 +69,16 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@AfterEach
public void tearDown() {
public void tearDown(Nd4jBackend backend) {
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
}
@Test
public void testReduceSumBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduceSumBP(Nd4jBackend backend) {
//Full array reduction
//reduce_sum_bp op: has 2 inputs (original pre-reduce input, and gradient at output (epsilon))
@ -104,7 +104,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testReduceSumAlongDim0BP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduceSumAlongDim0BP(Nd4jBackend backend) {
//Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar
@ -130,7 +132,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testReduceSumAlongDim1BP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduceSumAlongDim1BP(Nd4jBackend backend) {
//Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar
@ -158,7 +162,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test
public void testMeanBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanBP(Nd4jBackend backend) {
//dL/dIn_i = dL/dOut * dOut/dIn_i = dL/dOut * (1/N * sum_j (in_j))
// = 1/N * dL/dOut
@ -189,7 +195,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testMeanBP_Rank1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanBP_Rank1(Nd4jBackend backend) {
INDArray dLdOut = Nd4j.scalar(0.5);
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5 / 3);
@ -202,7 +210,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testMeanAlongDim0BP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanAlongDim0BP(Nd4jBackend backend) {
//Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar
@ -230,7 +240,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testMeanAlongDim1BP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanAlongDim1BP(Nd4jBackend backend) {
//Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar
@ -258,7 +270,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test
public void testMinBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMinBP(Nd4jBackend backend) {
//Full array min reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i
@ -297,7 +311,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testMinAlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMinAlongDimensionBP(Nd4jBackend backend) {
//Full array min reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i
@ -340,7 +356,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testMaxBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxBP(Nd4jBackend backend) {
//Full array max reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i
@ -370,7 +388,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testMaxAlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxAlongDimensionBP(Nd4jBackend backend) {
//Full array min reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i
@ -413,7 +433,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testProdBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testProdBP(Nd4jBackend backend) {
//Full array product reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i
@ -442,7 +464,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testProdAlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testProdAlongDimensionBP(Nd4jBackend backend) {
//dL/dIn_i = dL/dOut * dOut/dIn_i
// = dL/dOut * d(prod(in))/dIn_i
// = dL/dOut * (prod(in) / in_i)
@ -498,7 +522,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testStdevBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdevBP(Nd4jBackend backend) {
//If out = stdev(in) then:
//dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = (in_i-mean)/(stdev * (n-1))
@ -534,7 +560,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testStdevBP_Rank1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdevBP_Rank1(Nd4jBackend backend) {
INDArray dLdOut = Nd4j.scalar(0.5);
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
double stdev = preReduceInput.stdNumber(true).doubleValue();
@ -555,7 +583,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testStdevAlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdevAlongDimensionBP(Nd4jBackend backend) {
//If out = stdev(in) then:
//dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = (in_i-mean)/(stdev * (n-1))
@ -600,7 +630,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testVarianceBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVarianceBP(Nd4jBackend backend) {
//If out = variance(in) then:
//dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = 2*(in_i-mean)/(n-1)
@ -636,7 +668,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testVarianceAlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVarianceAlongDimensionBP(Nd4jBackend backend) {
//If out = variance(in) then:
//dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = 2*(in_i-mean)/(n-1)
@ -678,7 +712,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test
public void testCumSumBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCumSumBP(Nd4jBackend backend) {
//Standard case, non-reverse, non-exclusive
//dL/dIn_i = sum_j dL/dOut_j * dOut_j/dIn_i
// = sum_j dL/dOut_j * d(in_0 + ... + in_j)/dIn_i
@ -748,7 +784,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test
public void testNorm2Bp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm2Bp(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * x/|x|_2
@ -775,7 +813,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testNorm2AlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm2AlongDimensionBP(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * x/|x|_2
@ -808,7 +848,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testNorm1Bp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm1Bp(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * sgn(in)
@ -835,7 +877,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testNorm1AlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm1AlongDimensionBP(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * sgn(in)
@ -867,7 +911,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testNormMaxBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormMaxBp(Nd4jBackend backend) {
//out = max_i (|in_i|)
//dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise)
@ -897,7 +943,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testNormMaxAlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormMaxAlongDimensionBP(Nd4jBackend backend) {
//out = max_i (|in_i|)
//dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise)

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -76,16 +77,13 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
@Slf4j
@RunWith(Parameterized.class)
public class ReductionOpValidation extends BaseOpValidation {
public ReductionOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test
public void testStdev() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdev(Nd4jBackend backend) {
List<String> errors = new ArrayList<>();
for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345, DataType.DOUBLE)) {
@ -111,7 +109,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testZeroCount() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZeroCount(Nd4jBackend backend) {
List<String> allFailed = new ArrayList<>();
for (int i = 0; i < 21; i++) {
SameDiff sd = SameDiff.create();
@ -145,7 +145,9 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test
public void testZeroFraction() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZeroFraction(Nd4jBackend backend) {
List<String> allFailed = new ArrayList<>();
for (int i = 0; i < 2; i++) {
SameDiff sd = SameDiff.create();
@ -175,7 +177,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testReductionGradientsSimple() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionGradientsSimple(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES
//Test reductions: final and only function
Nd4j.getRandom().setSeed(12345);
@ -344,7 +348,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testReductionGradients1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionGradients1(Nd4jBackend backend) {
//Test reductions: final, but *not* the only function
Nd4j.getRandom().setSeed(12345);
@ -472,7 +478,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testReductionGradients2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionGradients2(Nd4jBackend backend) {
//Test reductions: NON-final function
Nd4j.getRandom().setSeed(12345);
@ -650,7 +658,9 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test
public void testReduce3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduce3(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int d0 = 3;
@ -755,7 +765,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testMoments() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMoments(Nd4jBackend backend) {
for (int[] axes : new int[][]{{0}, {1}, {0, 1}}) {
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -787,7 +799,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testMomentsOp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMomentsOp(Nd4jBackend backend) {
int[] axes = new int[]{0};
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -804,7 +818,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testNormalizeMomentsOp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormalizeMomentsOp(Nd4jBackend backend) {
INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10);
INDArray ssSum = data.sum(0);
INDArray ssSqSum = data.mul(data).sum(0);
@ -824,7 +840,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testAllAny() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllAny(Nd4jBackend backend) {
INDArray allZeros = Nd4j.zeros(DataType.FLOAT, 3, 4);
INDArray allOnes = Nd4j.ones(DataType.FLOAT, 3, 4);
@ -852,7 +870,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testIndexAccum() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexAccum(Nd4jBackend backend) {
List<String> failed = new ArrayList<>();
List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/);
@ -941,7 +961,9 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test
public void testReduce3_2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduce3_2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int d0 = 3;
@ -1039,7 +1061,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testReductionsBackwards() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionsBackwards(Nd4jBackend backend) {
// for (int i = 0; i < 7; i++) {
int i=5;
{
@ -1108,6 +1132,8 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttention(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1133,6 +1159,8 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionWithMask(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1163,6 +1191,8 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionMultiHeadInputWithMask(){
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
@ -1194,6 +1224,8 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionMultiHeadInput(){
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
@ -1221,6 +1253,8 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMultiHeadedDotProductAttention(){
final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
@ -1272,6 +1306,8 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionWeirdInputs(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1309,6 +1345,8 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMultiHeadedDotProductAttentionWeirdInputs(){
final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
@ -1366,7 +1404,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
}
@Test
public void testSufficientStatisticsOp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSufficientStatisticsOp(Nd4jBackend backend) {
INDArray data = Nd4j.createFromArray(new double[]{
5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1.,
1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5
@ -1392,7 +1432,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testStandardDeviation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardDeviation(Nd4jBackend backend) {
for (boolean keepDims : new boolean[]{false, true}) {
SameDiff sameDiff = SameDiff.create();
@ -1419,7 +1461,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testSquaredNorm() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSquaredNorm(Nd4jBackend backend) {
for (boolean keepDims : new boolean[]{false, true}) {
SameDiff sameDiff = SameDiff.create();
@ -1442,7 +1486,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testShannonEntropy() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShannonEntropy(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695
SameDiff sameDiff = SameDiff.create();
@ -1462,7 +1508,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testEntropy() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEntropy(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1481,7 +1529,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testAMean() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAMean(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1502,7 +1552,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testMean() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMean(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1523,7 +1575,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testNorm1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm1(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1544,7 +1598,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testNorm2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1565,7 +1621,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testNormMax() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormMax(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1586,7 +1644,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testSoftmaxCrossEntropyWithLogitsLoss() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSoftmaxCrossEntropyWithLogitsLoss(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create();

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
@ -43,12 +45,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j
public class RnnOpValidation extends BaseOpValidation {
public RnnOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test
public void testRnnBlockCell(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRnnBlockCell(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int mb = 2;
int nIn = 3;
@ -147,7 +148,9 @@ public class RnnOpValidation extends BaseOpValidation {
@Test
public void testRnnBlockCellManualTFCompare() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRnnBlockCellManualTFCompare(Nd4jBackend backend) {
//Test case: "rnn/lstmblockcell/static_batch1_n3-2_tsLength1_noPH_noClip_fBias1_noIS"
SameDiff sd = SameDiff.create();
@ -209,6 +212,8 @@ public class RnnOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGRUCell(){
Nd4j.getRandom().setSeed(12345);
int mb = 2;

View File

@ -28,6 +28,8 @@ import lombok.val;
import org.apache.commons.math3.linear.LUDecomposition;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -67,9 +69,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
@Slf4j
public class ShapeOpValidation extends BaseOpValidation {
public ShapeOpValidation(Nd4jBackend backend) {
super(backend);
}
/*
To test:
@ -83,7 +82,9 @@ public class ShapeOpValidation extends BaseOpValidation {
*/
@Test
public void testConcat() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat(Nd4jBackend backend) {
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
int[] concatDim = new int[]{0, 0, 0};
List<List<int[]>> origShapes = new ArrayList<>();
@ -123,7 +124,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testReshapeGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReshapeGradient(Nd4jBackend backend) {
//https://github.com/deeplearning4j/deeplearning4j/issues/6873
int[] origShape = new int[]{3, 4, 5};
@ -159,7 +162,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testPermuteGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermuteGradient(Nd4jBackend backend) {
int[] origShape = new int[]{3, 4, 5};
List<String> failed = new ArrayList<>();
@ -197,6 +202,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRank(){
List<long[]> inShape = Arrays.asList(null, new long[]{1}, new long[]{6}, new long[]{3,4}, new long[]{3,4,5});
@ -224,7 +231,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testExpandDimsGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExpandDimsGradient(Nd4jBackend backend) {
val origShape = new long[]{3, 4};
List<String> failed = new ArrayList<>();
@ -280,7 +289,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testSqueezeGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSqueezeGradient(Nd4jBackend backend) {
val origShape = new long[]{3, 4, 5};
List<String> failed = new ArrayList<>();
@ -344,7 +355,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testSliceGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSliceGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
//Order here: original shape, begin, size
@ -434,7 +447,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testStridedSliceGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
//Order here: original shape, begin, size
@ -497,7 +512,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testMerge() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMerge(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
@ -573,7 +590,7 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test()
public void testStack() {
public void testStack(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
@ -664,7 +681,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testUnStack() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnStack(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
@ -752,7 +771,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testTile() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTile(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<int[]> tileArg = Arrays.asList(
@ -824,6 +845,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTileBp(){
Nd4j.getRandom().setSeed(12345);
@ -857,6 +880,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTileBp2(){
Nd4j.getRandom().setSeed(12345);
@ -891,7 +916,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testReshape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReshape(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(-5, 6, 12)).reshape(3, 4);
SDVariable x = sameDiff.var("x", arr);
@ -907,7 +934,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testReshape2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReshape2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int[] origShape = new int[]{3, 4, 5};
@ -930,7 +959,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testTranspose() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTranspose(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
SDVariable x = sameDiff.var("x", arr);
@ -942,6 +973,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTransposeOp(){
INDArray arr = Nd4j.linspace(1,15, 15).reshape(5,3);
@ -955,7 +988,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShape(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
val shape = new long[]{2, 3};
SDVariable x = sameDiff.var("x", shape);
@ -970,7 +1005,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testSize() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSize(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
val shape = new long[]{2, 3};
SDVariable x = sameDiff.var("x", DataType.FLOAT, shape);
@ -984,7 +1021,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testDiagShapeFn() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagShapeFn(Nd4jBackend backend) {
INDArray i = Nd4j.linspace(1, 16, 16).reshape(4,4);
OpTestCase op = new OpTestCase(new DiagPart(i, null));
@ -998,6 +1037,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute(){
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
INDArray exp = in.permute(0,1,2); //No op
@ -1012,6 +1053,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute2(){
for (int[] perm : new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}) {
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
@ -1032,6 +1075,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConstant(){
//OpValidationSuite.ignoreFailing();
@ -1059,6 +1104,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnstackEdgeCase2(){
for( int i=0; i<3; i++ ) {
@ -1073,7 +1120,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void invertPermutation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void invertPermutation(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new float[] {3, 4, 0, 2, 1}).castTo(DataType.INT);
@ -1090,6 +1139,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherNd(){
List<INDArray> indices = new ArrayList<>();
@ -1128,7 +1179,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testReverseSequence() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReverseSequence(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
float[] input_data = new float[]{
1, 2, 3,
@ -1174,6 +1227,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixDeterminant(){
OpValidationSuite.ignoreFailing(); //Gradient check failing
@ -1195,6 +1250,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeterminant22(){
OpValidationSuite.ignoreFailing(); //Gradient check failing
@ -1219,6 +1276,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixDeterminant3(){
OpValidationSuite.ignoreFailing(); //Gradient checks failing
Nd4j.getRandom().setSeed(12345);
@ -1250,6 +1309,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixDeterminant4(){
OpValidationSuite.ignoreFailing(); //Gradient checks failing
Nd4j.getRandom().setSeed(12345);
@ -1270,6 +1331,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentOps(){
OpValidationSuite.ignoreFailing();
//https://github.com/deeplearning4j/deeplearning4j/issues/6952
@ -1362,6 +1425,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentMean(){
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3);
INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2);
@ -1382,7 +1447,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testSequenceMask() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSequenceMask(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2});
// arr is not trainable, so it's constant in model
@ -1416,6 +1483,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeshGrid(){
List<String> failed = new ArrayList<>();
@ -1472,6 +1541,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGather(){
List<INDArray> inArrs = new ArrayList<>();
List<Integer> axis = new ArrayList<>();
@ -1541,7 +1612,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testGatherSimple() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherSimple(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr = Nd4j.create(new float[]{1, 2, 3, 4}, new long[]{2, 2});
SDVariable x = sameDiff.var("x", arr);
@ -1551,7 +1624,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testGatherNdSingle() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherNdSingle(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(DataType.DOUBLE, 1, 24, 24)).reshape(2, 3, 4);
INDArray arr2 = Nd4j.create(new float[]{1, 2, 3, 0, 1, 3, 1, 0, 2}, new long[]{3, 3}).castTo(DataType.INT);
@ -1570,7 +1645,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testStack2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStack2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2);
@ -1581,7 +1658,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testParallelStack() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testParallelStack(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2);
@ -1593,7 +1672,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testUnStack2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnStack2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Nd4j.zeros(3, 2);
INDArray arr2 = Nd4j.ones(3, 2);
@ -1606,7 +1687,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testPermuteSimple() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermuteSimple(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3));
SDVariable x = sameDiff.var("x", arr);
@ -1617,7 +1700,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testConcat2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(4, 8, 4)).reshape(1,4);
@ -1628,7 +1713,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testTile2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTile2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1,4));
SDVariable x = sameDiff.var("x", arr);
@ -1641,7 +1728,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testSlice2d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlice2d(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
SameDiff sd = SameDiff.create();
@ -1657,7 +1746,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testSlice3d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlice3d(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create();
@ -1672,7 +1763,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testStridedSlice2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSlice2dBasic(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
SameDiff sd = SameDiff.create();
@ -1690,7 +1783,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testStridedSliceBeginEndMask() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceBeginEndMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
SameDiff sd = SameDiff.create();
@ -1705,7 +1800,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testStridedSliceEllipsisMask() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceEllipsisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr);
@ -1722,7 +1819,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testStridedSliceNewAxisMask() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceNewAxisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr);
@ -1735,7 +1834,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testStridedSliceNewAxisMask2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceNewAxisMask2(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr);
@ -1746,7 +1847,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testStridedSliceShrinkAxisMask() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceShrinkAxisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create();
@ -1763,7 +1866,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testSizeAt_1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSizeAt_1(Nd4jBackend backend) {
val array = Nd4j.create(10, 20, 30);
val exp = Nd4j.scalar(DataType.LONG, 20);
@ -1777,6 +1882,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEye(){
int[] rows = new int[]{3,3,3,3};
int[] cols = new int[]{3,2,2,2};
@ -1815,6 +1922,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSplit1(){
INDArray in = Nd4j.linspace(1,10,10).reshape(10);
INDArray axis = Nd4j.scalar(-1);
@ -1833,6 +1942,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSplit2(){
INDArray in = Nd4j.linspace(1,24,24).reshape(3,8);
INDArray axis = Nd4j.scalar(-1);
@ -1851,6 +1962,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDistancesExec(){
//https://github.com/deeplearning4j/deeplearning4j/issues/7001
for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) {
@ -1906,6 +2019,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionShape(){
INDArray shape = Nd4j.createFromArray(4,2);
@ -1924,6 +2039,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void gatherTest(){
INDArray in = Nd4j.createFromArray(new double[][]{
{1,2,3,4,5},
@ -1943,6 +2060,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSliceShape(){
INDArray arr = Nd4j.arange(0, 25).reshape(1,5,5).castTo(DataType.INT);
@ -1964,6 +2083,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testWhereAllFalse(){
INDArray in = Nd4j.create(DataType.BOOL, 1917);
DynamicCustomOp op = DynamicCustomOp.builder("Where")
@ -1978,6 +2099,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherScalar(){
INDArray in = Nd4j.linspace(100, 200, 100, DataType.FLOAT).reshape(100);
INDArray indices = Nd4j.scalar(0);
@ -2002,6 +2125,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCastEmpty(){
INDArray emptyLong = Nd4j.empty(DataType.LONG);
int dtype = 9; //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h
@ -2018,6 +2143,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherEmpty(){
/*
tf.reset_default_graph()
@ -2050,6 +2177,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSplitEmpty(){
/*
tf.reset_default_graph()
@ -2087,6 +2216,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcatEmpty(){
/*
TF behaviour with concatenatioun of empty arrays:
@ -2136,6 +2267,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcatEmpty2(){
INDArray empty10a = Nd4j.create(DataType.INT, 1, 0);
INDArray empty10b = Nd4j.create(DataType.INT, 1, 0);
@ -2168,6 +2301,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyGather(){
/*
tf.reset_default_graph()
@ -2200,6 +2335,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastDynamicShape1(){
//Test case: [2,1] and [4]: expect [2,4]
@ -2221,6 +2358,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastDynamicShape2(){
//Test case: [2,1,4] and [2,2,4]: expect [2,2,4]
@ -2243,6 +2382,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceShrinkAxis(){
INDArray in = Nd4j.create(DataType.DOUBLE, 3,2,2);
INDArray begin = Nd4j.createFromArray(2);
@ -2268,6 +2409,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceEmpty(){
INDArray in = Nd4j.createFromArray(10); //Integer, Length 1, rank 1, value 10 - Not used due to begin mask!
@ -2290,6 +2433,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceEdgeCase(){
INDArray in = Nd4j.scalar(10).reshape(1); //Int [1]
INDArray begin = Nd4j.ones(DataType.INT, 1);
@ -2315,6 +2460,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptySlice1(){
INDArray in = Nd4j.createFromArray(38);
INDArray begin = Nd4j.createFromArray(1);
@ -2334,6 +2481,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptySlice2(){
INDArray in = Nd4j.createFromArray(38);
INDArray begin = Nd4j.createFromArray(0);
@ -2353,6 +2502,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFill(){
INDArray shape = Nd4j.createFromArray(0,4);
@ -2372,6 +2523,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFill2(){
INDArray shape = Nd4j.createFromArray(0,4);
@ -2389,6 +2542,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermuteShapeDynamicAxis(){
DynamicCustomOp op = DynamicCustomOp.builder("permute")
@ -2418,6 +2573,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGather2(){
SameDiff sd = SameDiff.create();
SDVariable input = sd.var("in", Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3));
@ -2437,6 +2594,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute3(){
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
INDArray permute = Nd4j.createFromArray(1,0);
@ -2455,6 +2614,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute4(){
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
INDArray permute = Nd4j.createFromArray(1,0);
@ -2485,6 +2646,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvertPermutation(){
DynamicCustomOp op = DynamicCustomOp.builder("invert_permutation")
.addInputs(Nd4j.createFromArray(1, 0))
@ -2492,7 +2655,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testBroadcastInt1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastInt1(Nd4jBackend backend) {
INDArray out = Nd4j.create(DataType.INT, 1);
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
@ -2505,6 +2670,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastInt2(){
INDArray out = Nd4j.create(DataType.INT, 2);
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
@ -2544,7 +2711,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testMergeMaxIndex() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeMaxIndex(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2561,7 +2730,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testTriOp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTriOp(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable out = new Tri(sd, DataType.INT32, 3, 5, 2).outputVariable();
@ -2573,7 +2744,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testTriuOp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTriuOp(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable input = sd.var(Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {7,8,9},{10,11,12}}));

View File

@ -26,6 +26,8 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
@ -94,9 +96,6 @@ public class TransformOpValidation extends BaseOpValidation {
private DataType initialType;
public TransformOpValidation(Nd4jBackend backend) {
super(backend);
}
@BeforeEach
public void before() {
@ -120,7 +119,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testScalarOps() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarOps(Nd4jBackend backend) {
int d0 = 2;
int d1 = 3;
int d2 = 4;
@ -217,7 +218,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testScalarMulCF() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarMulCF(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
INDArray outC = Nd4j.createUninitialized(3, 4);
@ -231,7 +234,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testScalarMulCF2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarMulCF2(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
@ -242,7 +247,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testCross() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCross(Nd4jBackend backend) {
INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3});
INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3});
@ -270,7 +277,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testSpaceToDepth() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpaceToDepth(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(1337);
int miniBatch = 128;
@ -298,7 +307,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testDepthToSpace() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepthToSpace(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(1337);
int miniBatch = 128;
@ -325,7 +336,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testBatchToSpace() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBatchToSpace(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
Nd4j.getRandom().setSeed(1337);
@ -362,7 +375,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testSpaceToBatch() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpaceToBatch(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
Nd4j.getRandom().setSeed(7331);
@ -400,7 +415,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testDynamicPartition() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicPartition(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new double[]{4, 3, 5, 7, 8, 0});
@ -440,7 +457,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testDynamicPartition2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicPartition2(Nd4jBackend backend) {
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
@ -458,7 +477,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testDynamicStitch() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicStitch(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new double[]{5, 1, 3}, new long[]{3});
@ -495,7 +516,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testDiag() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiag(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new double[]{1, 2}, new int[]{2});
@ -521,7 +544,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testDiagPart() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagPart(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
INDArray input = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
@ -540,7 +565,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testEye() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEye(Nd4jBackend backend) {
int[] rows = new int[]{3, 3, 3, 3};
int[] cols = new int[]{3, 2, 2, 2};
int[][] batch = new int[][]{{}, {}, {4}, {3, 3}};
@ -574,7 +601,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testEyeShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEyeShape(Nd4jBackend backend) {
DynamicCustomOp dco = DynamicCustomOp.builder("eye")
.addIntegerArguments(3, 3)
//.addIntegerArguments(-99,3,3) //Also fails
@ -586,7 +615,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testTransforms() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTransforms(Nd4jBackend backend) {
//Test transforms (non-pairwise)
Nd4j.getRandom().setSeed(12345);
@ -1074,7 +1105,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testPairwiseTransforms() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPairwiseTransforms(Nd4jBackend backend) {
/*
add, sub, mul, div, rsub, rdiv
eq, neq, gt, lt, gte, lte, or, and, xor
@ -1258,7 +1291,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testIsX() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIsX(Nd4jBackend backend) {
List<String> failed = new ArrayList<>();
for (int i = 0; i < 4; i++) {
@ -1313,7 +1348,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testReplaceWhereScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReplaceWhereScalar(Nd4jBackend backend) {
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
log.info("Testing condition: " + c.getClass().getSimpleName());
@ -1335,7 +1372,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testReplaceWhereArray() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReplaceWhereArray(Nd4jBackend backend) {
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
INDArray inArr = Nd4j.rand(3, 4);
@ -1358,7 +1397,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testLogGrad() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogGrad(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
SDVariable input = sameDiff.var("x", Nd4j.linspace(1, 4, 4, DataType.DOUBLE));
SDVariable log = sameDiff.math().log(input);
@ -1369,7 +1410,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testSigmoidBackwards() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSigmoidBackwards(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
Map<String, INDArray> inputs = new HashMap<>();
@ -1387,7 +1430,9 @@ public class TransformOpValidation extends BaseOpValidation {
/* @Test
public void testDepth() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepth(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
SDVariable x = sameDiff.one("one",new long[]{2,2});
assertEquals(0,x.depth());
@ -1396,7 +1441,9 @@ public class TransformOpValidation extends BaseOpValidation {
}*/
@Test
public void testRank0EdgeCase() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRank0EdgeCase(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4})));
double d0 = v1.eval().getDouble(0);
@ -1409,7 +1456,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testAtan2BroadcastShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAtan2BroadcastShape(Nd4jBackend backend) {
INDArray arr1 = Nd4j.create(new long[]{3, 1, 4});
INDArray arr2 = Nd4j.create(new long[]{1, 2, 4});
@ -1424,7 +1473,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testBooleanAnd() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBooleanAnd(Nd4jBackend backend) {
Nd4j.setDataType(DataType.FLOAT);
INDArray arr1 = Nd4j.create(new long[]{3, 4});
INDArray arr2 = Nd4j.create(new long[]{3, 4});
@ -1438,7 +1489,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testScatterOpsScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScatterOpsScalar(Nd4jBackend backend) {
for (String s : new String[]{"add", "sub", "mul", "div"}) {
INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
INDArray indices = Nd4j.scalar(5);
@ -1483,7 +1536,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Disabled("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540")
@Test
public void testPad() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPad(Nd4jBackend backend) {
INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG);
INDArray value = Nd4j.scalar(10.0);
@ -1510,7 +1565,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testMirrorPad() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMirrorPad(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
@ -1543,7 +1600,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMirrorPad2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMirrorPad2(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
@ -1569,7 +1628,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMirrorPadSymmetric() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMirrorPadSymmetric(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT);
@ -1596,7 +1657,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testUnique() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnique(Nd4jBackend backend) {
INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4});
INDArray expUnique = Nd4j.create(new double[]{3, 4, 1, 0, 2});
@ -1618,7 +1681,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testTopK() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopK(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //Can't assume sorted here
INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8});
@ -1647,7 +1712,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testTopK1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopK1(Nd4jBackend backend) {
INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
INDArray k = Nd4j.scalar(1);
INDArray outValue = Nd4j.create(DataType.DOUBLE, 1);
@ -1668,7 +1735,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testInTopK() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInTopK(Nd4jBackend backend) {
for (int k = 4; k >= 1; k--) {
log.info("Testing: k=" + k);
INDArray in = Nd4j.linspace(1, 20, 20, DataType.DOUBLE).reshape(4, 5);
@ -1709,7 +1778,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testZeta() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZeta(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182
INDArray x = Nd4j.rand(3, 4).addi(1.0);
INDArray q = Nd4j.rand(3, 4);
@ -1726,7 +1797,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMaxEmptyScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxEmptyScalar(Nd4jBackend backend) {
INDArray empty = Nd4j.empty(DataType.FLOAT);
INDArray scalar = Nd4j.scalar(1.0f);
@ -1743,7 +1816,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testBroadcastEmpty() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastEmpty(Nd4jBackend backend) {
// Nd4j.getExecutioner().enableVerboseMode(true);
// Nd4j.getExecutioner().enableDebugMode(true);
//Check broadcast behaviour with empty arrays. The idea is to match TF import behaviour, for import
@ -1833,7 +1908,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testStandardize() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardize(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4});
final int[] axis = new int[]{1};
@ -1854,7 +1931,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testStandardizeOP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardizeOP(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4});
final int[] axis = new int[]{1};
@ -1869,7 +1948,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testStandardizeNoDeviation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardizeNoDeviation(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4});
for (int i = 0; i < 4; i++) {
random.putScalar(1, i, 7);
@ -1895,7 +1976,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMatMulTensor() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatMulTensor(Nd4jBackend backend) {
final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5});
final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6});
@ -1915,7 +1998,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMatMulTensorTranspose() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatMulTensorTranspose(Nd4jBackend backend) {
for (boolean transposeA : new boolean[]{false, true}) {
for (boolean transposeB : new boolean[]{false, true}) {
for (boolean transposeResult : new boolean[]{false, true}) {
@ -2008,7 +2093,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testSoftmaxCF() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSoftmaxCF(Nd4jBackend backend) {
INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5);
INDArray arrF = arrC.dup('f');
@ -2029,7 +2116,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testLogSumExp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogSumExp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4);
SameDiff sd = SameDiff.create();
@ -2044,7 +2133,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testLogSumExp2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogSumExp2(Nd4jBackend backend) {
for (int dim = 0; dim <= 2; dim++) {
Nd4j.getRandom().setSeed(12345);
@ -2065,7 +2156,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testCRELU() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCRELU(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2);
@ -2084,7 +2177,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testClipByAvgNorm() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByAvgNorm(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2, 2);
@ -2105,7 +2200,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testEmbeddingLookup() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmbeddingLookup(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2118,7 +2215,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testImageResize() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testImageResize(Nd4jBackend backend) {
//TODO: Methods failed ResizeLanczos5, ResizeMitchelcubic, ResizeArea
@ -2160,7 +2259,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testMaximumBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaximumBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2177,7 +2278,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMergeAddBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeAddBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2194,7 +2297,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMergeMaxBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeMaxBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2212,7 +2317,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testMergeAvgBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeAvgBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2229,7 +2336,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testReverseBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReverseBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2243,7 +2352,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testUpsampling3dBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUpsampling3dBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
for (boolean dataformat : new boolean[]{true, false}) {

View File

@ -24,8 +24,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
@ -36,11 +37,8 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.factory.Nd4jBackend;
public class ConvConfigTests extends BaseNd4jTest {
public class ConvConfigTests extends BaseNd4jTestWithBackends {
public ConvConfigTests(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -48,7 +46,9 @@ public class ConvConfigTests extends BaseNd4jTest {
}
@Test
public void testDeConv2D(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeConv2D(Nd4jBackend backend){
DeConv2DConfig.builder().kH(2).kW(4).build();
try{
@ -109,7 +109,9 @@ public class ConvConfigTests extends BaseNd4jTest {
}
@Test
public void testConv2D(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv2D(Nd4jBackend backend){
Conv2DConfig.builder().kH(2).kW(4).build();
try{
@ -170,7 +172,9 @@ public class ConvConfigTests extends BaseNd4jTest {
}
@Test
public void testPooling2D(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPooling2D(Nd4jBackend backend){
Pooling2DConfig.builder().kH(2).kW(4).build();
try{
@ -231,7 +235,9 @@ public class ConvConfigTests extends BaseNd4jTest {
}
@Test
public void testDeConv3D(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeConv3D(Nd4jBackend backend){
DeConv3DConfig.builder().kH(2).kW(4).kD(3).build();
try{
@ -320,7 +326,9 @@ public class ConvConfigTests extends BaseNd4jTest {
}
@Test
public void testConv3D(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv3D(Nd4jBackend backend){
Conv3DConfig.builder().kH(2).kW(4).kD(3).build();
try{
@ -411,7 +419,9 @@ public class ConvConfigTests extends BaseNd4jTest {
@Test
public void testPooling3D(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPooling3D(Nd4jBackend backend){
Pooling3DConfig.builder().kH(2).kW(4).kD(3).build();
try{
@ -500,6 +510,8 @@ public class ConvConfigTests extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1D(){
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();

View File

@ -23,8 +23,10 @@ package org.nd4j.autodiff.samediff;
import lombok.val;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -40,11 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Disabled("AB 2019/05/21 - JVM Crash on ppc64 - Issue #7657")
public class FailingSameDiffTests extends BaseNd4jTest {
public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
public FailingSameDiffTests(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -52,7 +51,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
}
@Test
public void testEye(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEye(Nd4jBackend backend){
//OpValidationSuite.ignoreFailing();
INDArray arr = Nd4j.create(new double[]{1, 0, 0, 0, 1, 0}, new int[]{2, 3});
List<INDArray> stack = new ArrayList<>();
@ -68,7 +69,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
}
@Test
public void testEyeShape(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEyeShape(Nd4jBackend backend){
val dco = DynamicCustomOp.builder("eye")
.addIntegerArguments(3,3)
//.addIntegerArguments(-99,3,3) //Also fails
@ -80,7 +83,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
}
@Test
public void testExecutionDifferentShapesTransform(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExecutionDifferentShapesTransform(Nd4jBackend backend){
OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4));
@ -101,7 +106,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
}
@Test
public void testDropout() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDropout(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create();
double p = 0.5;
@ -114,7 +121,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
}
@Test
public void testExecutionDifferentShapesDynamicCustom(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExecutionDifferentShapesDynamicCustom(Nd4jBackend backend){
OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create();

View File

@ -26,13 +26,15 @@ import org.apache.commons.io.IOUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.graph.FlatConfiguration;
import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatVariable;
import org.nd4j.graph.IntPair;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
@ -70,11 +72,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j
public class FlatBufferSerdeTest extends BaseNd4jTest {
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
public FlatBufferSerdeTest(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -84,7 +83,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
@Test
public void testBasic(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create();
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() );
@ -139,7 +140,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
}
@Test
public void testSimple(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
for( int i = 0; i < 10; i++ ) {
for(boolean execFirst : new boolean[]{false, true}) {
log.info("Starting test: i={}, execFirst={}", i, execFirst);
@ -268,7 +271,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
@Test
public void testTrainingSerde(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
//Ensure 2 things:
//1. Training config is serialized/deserialized correctly
@ -352,7 +357,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
@Test
public void pooling3DSerialization(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void pooling3DSerialization(Nd4jBackend backend){
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28);
@ -372,7 +379,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
}
@Test
public void pooling3DSerialization2(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void pooling3DSerialization2(Nd4jBackend backend){
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28);

View File

@ -22,12 +22,14 @@ package org.nd4j.autodiff.samediff;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
import org.nd4j.autodiff.samediff.transform.OpPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraph;
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
@ -42,11 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j
public class GraphTransformUtilTests extends BaseNd4jTest {
public class GraphTransformUtilTests extends BaseNd4jTestWithBackends {
public GraphTransformUtilTests(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -54,7 +53,9 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
}
@Test
public void testBasic(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBasic(Nd4jBackend backend){
SameDiff sd = SameDiff.create();
SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 32);
@ -93,7 +94,9 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
}
@Test
public void testSubgraphReplace1(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSubgraphReplace1(Nd4jBackend backend){
SameDiff sd = SameDiff.create();
SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 4);

View File

@ -21,8 +21,10 @@
package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -32,11 +34,8 @@ import java.lang.reflect.Field;
import static org.junit.jupiter.api.Assertions.*;
public class MemoryMgrTest extends BaseNd4jTest {
public class MemoryMgrTest extends BaseNd4jTestWithBackends {
public MemoryMgrTest(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -44,7 +43,9 @@ public class MemoryMgrTest extends BaseNd4jTest {
}
@Test
public void testArrayReuseTooLarge() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testArrayReuseTooLarge(Nd4jBackend backend) throws Exception {
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
Field f = ArrayCacheMemoryMgr.class.getDeclaredField("maxCacheBytes");
@ -116,7 +117,9 @@ public class MemoryMgrTest extends BaseNd4jTest {
}
@Test
public void testManyArrays(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testManyArrays(Nd4jBackend backend){
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
for( int i = 0; i < 1000; i++) {

View File

@ -21,9 +21,11 @@
package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -35,11 +37,8 @@ import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class NameScopeTests extends BaseNd4jTest {
public class NameScopeTests extends BaseNd4jTestWithBackends {
public NameScopeTests(Nd4jBackend b){
super(b);
}
@Override
public char ordering() {
@ -47,7 +46,9 @@ public class NameScopeTests extends BaseNd4jTest {
}
@Test
public void testVariableNameScopesBasic(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVariableNameScopesBasic(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable v = sd.var("x");
@ -73,7 +74,9 @@ public class NameScopeTests extends BaseNd4jTest {
}
@Test
public void testOpFieldsAndNames(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOpFieldsAndNames(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable x = sd.var("x", DataType.FLOAT, 1);
@ -151,7 +154,9 @@ public class NameScopeTests extends BaseNd4jTest {
}
@Test
public void testNoNesting(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNoNesting(Nd4jBackend backend) {
SameDiff SD = SameDiff.create();
SDVariable a = SD.constant(4);
@ -168,7 +173,9 @@ public class NameScopeTests extends BaseNd4jTest {
}
@Test
public void testNoTesting2(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNoTesting2(Nd4jBackend backend) {
SameDiff SD = SameDiff.create();
SDVariable a = SD.constant(4);

View File

@ -21,21 +21,16 @@
package org.nd4j.autodiff.samediff;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.imports.tfgraphs.TFGraphTestZooModels;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.io.File;
import java.nio.file.Path;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
@ -55,7 +50,9 @@ public class SameDiffMultiThreadTests extends BaseND4JTest {
}
@Test
public void testSimple() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(Nd4jBackend backend) throws Exception {
int nThreads = 4;
int nRuns = 1000;
@ -103,48 +100,6 @@ public class SameDiffMultiThreadTests extends BaseND4JTest {
}
}
@Test
@Disabled //2020/03/24 AB - https://github.com/eclipse/deeplearning4j/issues/8802
public void testMobilenet(@TempDir Path testDir) throws Exception {
TFGraphTestZooModels.currentTestDir = testDir.toFile();
File f = Resources.asFile("tf_graphs/zoo_models/mobilenet_v2_1.0_224/tf_model.txt");
SameDiff sd = TFGraphTestZooModels.LOADER.apply(f, "mobilenet_v2_1.0_224");
// System.out.println(sd.summary());
int nThreads = 4;
int nRuns = 30;
INDArray[] inputArrs = new INDArray[nThreads];
INDArray[] expOut = new INDArray[nThreads];
for( int i=0; i<nThreads; i++ ){
if(i == 0 || i > 2)
inputArrs[i] = Nd4j.rand(DataType.FLOAT, 1, 224, 224, 3);
else if(i == 1)
inputArrs[i] = Nd4j.zeros(DataType.FLOAT, 1, 224, 224, 3);
else if(i == 2)
inputArrs[i] = Nd4j.ones(DataType.FLOAT, 1, 224, 224, 3);
expOut[i] = sd.outputSingle(Collections.singletonMap("input", inputArrs[i]), "MobilenetV2/Predictions/Reshape_1");
Nd4j.getExecutioner().commit();
}
AtomicBoolean[] failuresByThread = new AtomicBoolean[nThreads];
AtomicInteger[] counters = new AtomicInteger[nThreads];
Semaphore s = new Semaphore(nThreads);
CountDownLatch latch = new CountDownLatch(nThreads);
doTest(sd, nThreads, nRuns, inputArrs, expOut, "input", "MobilenetV2/Predictions/Reshape_1", failuresByThread, counters, s, latch);
s.release(nThreads);
latch.await();
for(int i = 0; i < nThreads; i++) {
assertFalse( failuresByThread[i].get(),"Thread " + i + " failed");
}
for(int i = 0; i < nThreads; i++) {
assertEquals( nRuns, counters[i].get(),"Thread " + i + " number of runs");
}
}
public static void doTest(SameDiff sd, int nThreads, int nRuns, INDArray[] inputArrs, INDArray[] expOut,

View File

@ -21,7 +21,9 @@
package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@ -31,14 +33,13 @@ import org.nd4j.linalg.learning.config.Sgd;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class SameDiffOutputTest extends BaseNd4jTest {
public class SameDiffOutputTest extends BaseNd4jTestWithBackends {
public SameDiffOutputTest(Nd4jBackend backend) {
super(backend);
}
@Test
public void outputTest(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void outputTest(Nd4jBackend backend){
DataSet data = new DataSet(Nd4j.zeros(10, 10), Nd4j.zeros(10, 10));
SameDiff sd = SameDiff.create();

View File

@ -21,7 +21,9 @@
package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@ -35,11 +37,8 @@ import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.assertNull;
import static org.junit.jupiter.api.Assertions.*;
public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
public SameDiffSpecifiedLossVarsTests(Nd4jBackend b){
super(b);
}
@Override
public char ordering() {
@ -47,7 +46,9 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
}
@Test
public void testSpecifiedLoss1(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpecifiedLoss1(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable ph1 = sd.var("ph", DataType.FLOAT, 3, 4);
ph1.setArray(Nd4j.create(DataType.FLOAT, 3, 4));
@ -68,7 +69,9 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
}
@Test
public void testSpecifiedLoss2(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpecifiedLoss2(Nd4jBackend backend) {
for( int i=0; i<2; i++ ) {
SameDiff sd = SameDiff.create();
SDVariable ph = sd.placeHolder("ph", DataType.FLOAT, 3, 4);
@ -121,7 +124,9 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
@Test
public void testTrainingDifferentLosses(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingDifferentLosses(Nd4jBackend backend) {
//Net with 2 losses: train on the first one, then change losses
//Also check that if modifying via add/setLossVariables the training config changes

View File

@ -30,11 +30,13 @@ import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@ -55,14 +57,13 @@ import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.weightinit.impl.XavierInitScheme;
@Slf4j
public class SameDiffTrainingTest extends BaseNd4jTest {
public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
public SameDiffTrainingTest(Nd4jBackend backend) {
super(backend);
}
@Test
public void irisTrainingSanityCheck() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisTrainingSanityCheck(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize();
@ -134,7 +135,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
@Test
public void irisTrainingEvalTest() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisTrainingEvalTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize();
@ -184,7 +187,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
@Test
public void irisTrainingValidationTest() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisTrainingValidationTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize();
@ -239,6 +244,8 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingMixedDtypes(){
for (String u : new String[]{"adam", "nesterov", "adamax", "amsgrad"}) {
@ -301,7 +308,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
}
@Test
public void simpleClassification() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void simpleClassification(Nd4jBackend backend) {
double learning_rate = 0.001;
int seed = 7;
org.nd4j.linalg.api.rng.Random rng = Nd4j.getRandom();
@ -348,6 +357,8 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingEvalVarNotReqForLoss(){
//If a variable is not required for the loss - normally it won't be calculated
//But we want to make sure it IS calculated here - so we can perform evaluation on it

View File

@ -25,11 +25,13 @@ import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.IrisDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -48,11 +50,8 @@ import java.util.concurrent.TimeUnit;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class CheckpointListenerTest extends BaseNd4jTest {
public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
public CheckpointListenerTest(Nd4jBackend backend){
super(backend);
}
@Override
public char ordering(){
@ -96,7 +95,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
@Test
public void testCheckpointEveryEpoch(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile();
SameDiff sd = getModel();
@ -130,7 +131,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
}
@Test
public void testCheckpointEvery5Iter(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile();
SameDiff sd = getModel();
@ -169,7 +172,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
@Test
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile();
SameDiff sd = getModel();
@ -213,7 +218,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
}
@Test
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile();
SameDiff sd = getModel();

View File

@ -21,12 +21,13 @@
package org.nd4j.autodiff.samediff.listeners;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@ -34,14 +35,13 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.Adam;
public class ExecDebuggingListenerTest extends BaseNd4jTest {
public class ExecDebuggingListenerTest extends BaseNd4jTestWithBackends {
public ExecDebuggingListenerTest(Nd4jBackend backend) {
super(backend);
}
@Test
public void testExecDebugListener(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExecDebugListener(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);

View File

@ -21,6 +21,8 @@
package org.nd4j.autodiff.samediff.listeners;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Listener;
@ -38,7 +40,7 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
@ -61,11 +63,8 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
public class ListenerTest extends BaseNd4jTest {
public class ListenerTest extends BaseNd4jTestWithBackends {
public ListenerTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -73,7 +72,9 @@ public class ListenerTest extends BaseNd4jTest {
}
@Test
public void irisHistoryTest() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisHistoryTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize();
@ -136,6 +137,8 @@ public class ListenerTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testListenerCalls(){
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
@ -273,7 +276,9 @@ public class ListenerTest extends BaseNd4jTest {
}
@Test
public void testCustomListener() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCustomListener(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);

View File

@ -26,11 +26,13 @@ import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.profiler.ProfilingListener;
import org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -45,11 +47,8 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
public class ProfilingListenerTest extends BaseNd4jTest {
public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
public ProfilingListenerTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -59,7 +58,9 @@ public class ProfilingListenerTest extends BaseNd4jTest {
@Test
public void testProfilingListenerSimple(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
@ -108,18 +109,24 @@ public class ProfilingListenerTest extends BaseNd4jTest {
/*
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoadTfProfile(){
File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json");
ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoadTfProfileDir(){
File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles");
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoadTfProfileDir2(){
File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0");
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);

View File

@ -27,6 +27,8 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
@ -41,7 +43,7 @@ import org.nd4j.graph.UIInfoType;
import org.nd4j.graph.UIOp;
import org.nd4j.graph.UIVariable;
import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -60,11 +62,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j
public class FileReadWriteTests extends BaseNd4jTest {
public class FileReadWriteTests extends BaseNd4jTestWithBackends {
public FileReadWriteTests(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -81,7 +80,9 @@ public class FileReadWriteTests extends BaseNd4jTest {
}
@Test
public void testSimple(@TempDir Path testDir) throws IOException {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
SameDiff sd = SameDiff.create();
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
SDVariable sum = v.sum();
@ -185,7 +186,9 @@ public class FileReadWriteTests extends BaseNd4jTest {
}
@Test
public void testNullBinLabels(@TempDir Path testDir) throws Exception{
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{
File dir = testDir.toFile();
File f = new File(dir, "temp.bin");
LogFileWriter w = new LogFileWriter(f);

View File

@ -25,6 +25,8 @@ import com.google.flatbuffers.Table;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.impl.UIListener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -34,7 +36,7 @@ import org.nd4j.graph.UIEvent;
import org.nd4j.graph.UIGraphStructure;
import org.nd4j.graph.UIStaticInfoRecord;
import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.IrisDataSetIterator;
@ -51,11 +53,8 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
public class UIListenerTest extends BaseNd4jTest {
public class UIListenerTest extends BaseNd4jTestWithBackends {
public UIListenerTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -65,7 +64,9 @@ public class UIListenerTest extends BaseNd4jTest {
@Test
public void testUIListenerBasic(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
Nd4j.getRandom().setSeed(12345);
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -101,7 +102,9 @@ public class UIListenerTest extends BaseNd4jTest {
}
@Test
public void testUIListenerContinue(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
SameDiff sd1 = getSimpleNet();
@ -192,7 +195,9 @@ public class UIListenerTest extends BaseNd4jTest {
}
@Test
public void testUIListenerBadContinue(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
SameDiff sd1 = getSimpleNet();

View File

@ -23,18 +23,17 @@ package org.nd4j.evaluation;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.custom.CustomEvaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.primitives.Pair;
public class CustomEvaluationTest extends BaseNd4jTest {
public class CustomEvaluationTest extends BaseNd4jTestWithBackends {
public CustomEvaluationTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -42,8 +41,10 @@ public class CustomEvaluationTest extends BaseNd4jTest {
}
@Test
public void customEvalTest(){
CustomEvaluation accuracyEval = new CustomEvaluation<Pair<Number, Long>>(
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void customEvalTest(Nd4jBackend backend){
CustomEvaluation accuracyEval = new CustomEvaluation<>(
(labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)),
CustomEvaluation.mergeConcatenate());

View File

@ -21,6 +21,8 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration;
@ -29,17 +31,14 @@ import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
public class EmptyEvaluationTests extends BaseNd4jTest {
public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
public EmptyEvaluationTests(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -47,7 +46,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
}
@Test
public void testEmptyEvaluation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyEvaluation (Nd4jBackend backend) {
Evaluation e = new Evaluation();
System.out.println(e.stats());
@ -62,7 +63,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
}
@Test
public void testEmptyRegressionEvaluation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyRegressionEvaluation (Nd4jBackend backend) {
RegressionEvaluation re = new RegressionEvaluation();
re.stats();
@ -76,7 +79,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
}
@Test
public void testEmptyEvaluationBinary() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyEvaluationBinary(Nd4jBackend backend) {
EvaluationBinary eb = new EvaluationBinary();
eb.stats();
@ -91,7 +96,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
}
@Test
public void testEmptyROC() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyROC(Nd4jBackend backend) {
ROC roc = new ROC();
roc.stats();
@ -106,7 +113,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
}
@Test
public void testEmptyROCBinary() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyROCBinary(Nd4jBackend backend) {
ROCBinary rb = new ROCBinary();
rb.stats();
@ -121,7 +130,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
}
@Test
public void testEmptyROCMultiClass() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyROCMultiClass(Nd4jBackend backend) {
ROCMultiClass r = new ROCMultiClass();
r.stats();
@ -136,7 +147,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
}
@Test
public void testEmptyEvaluationCalibration() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyEvaluationCalibration(Nd4jBackend backend) {
EvaluationCalibration ec = new EvaluationCalibration();
ec.stats();

View File

@ -21,9 +21,11 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
@ -36,11 +38,8 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class EvalCustomThreshold extends BaseNd4jTest {
public class EvalCustomThreshold extends BaseNd4jTestWithBackends {
public EvalCustomThreshold(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -48,7 +47,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
}
@Test
public void testEvaluationCustomBinaryThreshold() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCustomBinaryThreshold(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
//Sanity checks: 0.5 threshold for 1-output and 2-output binary cases
@ -114,7 +115,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
}
@Test
public void testEvaluationCostArray() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCostArray(Nd4jBackend backend) {
int nExamples = 20;
@ -162,7 +165,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
}
@Test
public void testEvaluationBinaryCustomThreshold() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryCustomThreshold(Nd4jBackend backend) {
//Sanity check: same results for 0.5 threshold vs. default (no threshold)
int nExamples = 20;

View File

@ -21,6 +21,8 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration;
@ -31,7 +33,7 @@ import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
@ -42,11 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class EvalJsonTest extends BaseNd4jTest {
public class EvalJsonTest extends BaseNd4jTestWithBackends {
public EvalJsonTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -54,7 +53,9 @@ public class EvalJsonTest extends BaseNd4jTest {
}
@Test
public void testSerdeEmpty() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerdeEmpty(Nd4jBackend backend) {
boolean print = false;
IEvaluation[] arr = new IEvaluation[] {new Evaluation(), new EvaluationBinary(), new ROCBinary(10),
@ -74,7 +75,9 @@ public class EvalJsonTest extends BaseNd4jTest {
}
@Test
public void testSerde() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerde(Nd4jBackend backend) {
boolean print = false;
Nd4j.getRandom().setSeed(12345);
@ -122,7 +125,9 @@ public class EvalJsonTest extends BaseNd4jTest {
}
@Test
public void testSerdeExactRoc() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerdeExactRoc(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
boolean print = false;
@ -200,7 +205,9 @@ public class EvalJsonTest extends BaseNd4jTest {
}
@Test
public void testJsonYamlCurves() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testJsonYamlCurves(Nd4jBackend backend) {
ROC roc = new ROC(0);
INDArray evalLabel =
@ -252,7 +259,9 @@ public class EvalJsonTest extends BaseNd4jTest {
}
@Test
public void testJsonWithCustomThreshold() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testJsonWithCustomThreshold(Nd4jBackend backend) {
//Evaluation - binary threshold
Evaluation e = new Evaluation(0.25);

View File

@ -21,8 +21,10 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -39,11 +41,8 @@ import static org.junit.jupiter.api.Assertions.*;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
public class EvalTest extends BaseNd4jTest {
public class EvalTest extends BaseNd4jTestWithBackends {
public EvalTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -52,7 +51,9 @@ public class EvalTest extends BaseNd4jTest {
@Test
public void testEval() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEval(Nd4jBackend backend) {
int classNum = 5;
Evaluation eval = new Evaluation (classNum);
@ -91,7 +92,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testEval2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEval2(Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
Evaluation first = null;
@ -150,7 +153,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testStringListLabels() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStringListLabels(Nd4jBackend backend) {
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
@ -167,7 +172,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testStringHashLabels() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStringHashLabels(Nd4jBackend backend) {
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
@ -184,7 +191,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testEvalMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalMasking(Nd4jBackend backend) {
int miniBatch = 5;
int nOut = 3;
int tsLength = 6;
@ -251,7 +260,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testFalsePerfectRecall() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFalsePerfectRecall(Nd4jBackend backend) {
int testSize = 100;
int numClasses = 5;
int winner = 1;
@ -284,7 +295,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testEvaluationMerging() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationMerging(Nd4jBackend backend) {
int nRows = 20;
int nCols = 3;
@ -358,7 +371,9 @@ public class EvalTest extends BaseNd4jTest {
@Test
public void testSingleClassBinaryClassification() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSingleClassBinaryClassification(Nd4jBackend backend) {
Evaluation eval = new Evaluation(1);
@ -387,7 +402,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testEvalInvalid() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalInvalid(Nd4jBackend backend) {
Evaluation e = new Evaluation(5);
e.eval(0, 1);
e.eval(1, 0);
@ -400,7 +417,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testEvalMethods() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalMethods(Nd4jBackend backend) {
//Check eval(int,int) vs. eval(INDArray,INDArray)
Evaluation e1 = new Evaluation(4);
@ -443,7 +462,9 @@ public class EvalTest extends BaseNd4jTest {
@Test
public void testTopNAccuracy() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopNAccuracy(Nd4jBackend backend) {
Evaluation e = new Evaluation(null, 3);
@ -504,7 +525,9 @@ public class EvalTest extends BaseNd4jTest {
@Test
public void testTopNAccuracyMerging() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopNAccuracyMerging(Nd4jBackend backend) {
Evaluation e1 = new Evaluation(null, 3);
Evaluation e2 = new Evaluation(null, 3);
@ -552,7 +575,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testBinaryCase() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBinaryCase(Nd4jBackend backend) {
INDArray ones10 = Nd4j.ones(10, 1);
INDArray ones4 = Nd4j.ones(4, 1);
INDArray zeros4 = Nd4j.zeros(4, 1);
@ -581,7 +606,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testF1FBeta_MicroMacroAveraging() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testF1FBeta_MicroMacroAveraging(Nd4jBackend backend) {
//Confusion matrix: rows = actual, columns = predicted
//[3, 1, 0]
//[2, 2, 1]
@ -722,7 +749,9 @@ public class EvalTest extends BaseNd4jTest {
@Test
public void testConfusionMatrixStats() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConfusionMatrixStats(Nd4jBackend backend) {
Evaluation e = new Evaluation();
@ -743,6 +772,8 @@ public class EvalTest extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalBinaryMetrics(){
Evaluation ePosClass1_nOut2 = new Evaluation(2, 1);
@ -864,6 +895,8 @@ public class EvalTest extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConfusionMatrixString(){
Evaluation e = new Evaluation(Arrays.asList("a","b","c"));
@ -914,6 +947,8 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationNaNs(){
Evaluation e = new Evaluation();
@ -929,6 +964,8 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345);
@ -1023,6 +1060,8 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLabelReset(){
Map<Integer,String> m = new HashMap<>();
@ -1056,6 +1095,8 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalStatsBinaryCase(){
//Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case

View File

@ -21,9 +21,11 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -38,11 +40,8 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.nd4j.evaluation.classification.EvaluationBinary.Metric.*;
public class EvaluationBinaryTest extends BaseNd4jTest {
public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
public EvaluationBinaryTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -50,7 +49,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinary() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary(Nd4jBackend backend) {
//Compare EvaluationBinary to Evaluation class
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
EvaluationBinary first = null;
@ -136,7 +137,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinaryMerging() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryMerging(Nd4jBackend backend) {
int nOut = 4;
int[] shape1 = {30, nOut};
int[] shape2 = {50, nOut};
@ -163,7 +166,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinaryPerOutputMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryPerOutputMasking(Nd4jBackend backend) {
//Provide a mask array: "ignore" the masked steps
@ -206,7 +211,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testTimeSeriesEval() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTimeSeriesEval(Nd4jBackend backend) {
int[] shape = {2, 4, 3};
Nd4j.getRandom().setSeed(12345);
@ -230,7 +237,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinaryWithROC() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryWithROC(Nd4jBackend backend) {
//Simple test for nested ROCBinary in EvaluationBinary
Nd4j.getRandom().setSeed(12345);
@ -247,7 +256,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
@Test
public void testEvaluationBinary3d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -281,7 +292,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinary4d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -315,7 +328,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinary3dMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -376,7 +391,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinary4dMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -21,8 +21,10 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -39,11 +41,8 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
public class EvaluationCalibrationTest extends BaseNd4jTest {
public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
public EvaluationCalibrationTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering () {
@ -51,7 +50,9 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
}
@Test
public void testReliabilityDiagram() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReliabilityDiagram (Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
EvaluationCalibration first = null;
@ -143,7 +144,9 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
}
@Test
public void testLabelAndPredictionCounts() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLabelAndPredictionCounts (Nd4jBackend backend) {
int minibatch = 50;
int nClasses = 3;
@ -171,7 +174,9 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
}
@Test
public void testResidualPlots() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testResidualPlots (Nd4jBackend backend) {
int minibatch = 50;
int nClasses = 3;
@ -272,6 +277,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345);
@ -366,7 +373,9 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
}
@Test
public void testEvaluationCalibration3d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCalibration3d (Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -398,7 +407,9 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
}
@Test
public void testEvaluationCalibration3dMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCalibration3dMasking (Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);

View File

@ -23,6 +23,8 @@ package org.nd4j.evaluation;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration;
@ -30,17 +32,14 @@ import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
public class NewInstanceTest extends BaseNd4jTest {
public class NewInstanceTest extends BaseNd4jTestWithBackends {
public NewInstanceTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -48,7 +47,9 @@ public class NewInstanceTest extends BaseNd4jTest {
}
@Test
public void testNewInstances() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewInstances(Nd4jBackend backend) {
boolean print = true;
Nd4j.getRandom().setSeed(12345);

View File

@ -21,10 +21,12 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -39,11 +41,7 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class ROCBinaryTest extends BaseNd4jTest {
public ROCBinaryTest(Nd4jBackend backend) {
super(backend);
}
public class ROCBinaryTest extends BaseNd4jTestWithBackends {
@Override
public char ordering() {
@ -51,7 +49,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
}
@Test
public void testROCBinary() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary(Nd4jBackend backend) {
//Compare ROCBinary to ROC class
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
@ -146,7 +146,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
}
@Test
public void testRocBinaryMerging() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBinaryMerging(Nd4jBackend backend) {
for (int nSteps : new int[]{30, 0}) { //0 == exact
int nOut = 4;
int[] shape1 = {30, nOut};
@ -176,7 +178,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
@Test
public void testROCBinaryPerOutputMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinaryPerOutputMasking(Nd4jBackend backend) {
for (int nSteps : new int[]{30, 0}) { //0 == exact
@ -216,7 +220,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
@Test
public void testROCBinary3d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -250,7 +256,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
}
@Test
public void testROCBinary4d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -284,7 +292,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
}
@Test
public void testROCBinary3dMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -345,7 +355,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
}
@Test
public void testROCBinary4dMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -21,12 +21,14 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -39,11 +41,8 @@ import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
public class ROCTest extends BaseNd4jTest {
public class ROCTest extends BaseNd4jTestWithBackends {
public ROCTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -84,7 +83,9 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void testRocBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBasic(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax)
INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
@ -127,7 +128,9 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void testRocBasicSingleClass() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBasicSingleClass(Nd4jBackend backend) {
//1 output here - single probability value (sigmoid)
//add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
@ -165,7 +168,9 @@ public class ROCTest extends BaseNd4jTest {
@Test
public void testRoc() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRoc(Nd4jBackend backend) {
//Previous tests allowed for a perfect classifier with right threshold...
INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}});
@ -250,7 +255,9 @@ public class ROCTest extends BaseNd4jTest {
@Test
public void testRocTimeSeriesNoMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
//Same as first test...
//2 outputs here - probability distribution over classes (softmax)
@ -297,7 +304,9 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void testRocTimeSeriesMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocTimeSeriesMasking(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax)
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
@ -347,7 +356,9 @@ public class ROCTest extends BaseNd4jTest {
@Test
public void testCompareRocAndRocMultiClass() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
//For 2 class case: ROC and Multi-class ROC should be the same...
@ -377,7 +388,9 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void testCompare2Vs3Classes() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCompare2Vs3Classes(Nd4jBackend backend) {
//ROC multi-class: 2 vs. 3 classes should be the same, if we add two of the classes together...
//Both methods implement one vs. all ROC/AUC in different ways
@ -426,7 +439,9 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void testROCMerging() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCMerging(Nd4jBackend backend) {
int nArrays = 10;
int minibatch = 64;
int nROCs = 3;
@ -471,7 +486,9 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void testROCMerging2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCMerging2(Nd4jBackend backend) {
int nArrays = 10;
int minibatch = 64;
int exactAllocBlockSize = 10;
@ -516,7 +533,9 @@ public class ROCTest extends BaseNd4jTest {
@Test
public void testROCMultiMerging() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCMultiMerging(Nd4jBackend backend) {
int nArrays = 10;
int minibatch = 64;
@ -564,7 +583,9 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void testAUCPrecisionRecall() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAUCPrecisionRecall(Nd4jBackend backend) {
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
//at threshold 0 to 0.24999: tp=2, fp=1, fn=0, tn=0 prec=2/(2+1)=0.666, recall=2/2=1.0
//at threshold 0.25 to 0.33: tp=2, fp=0, fn=0, tn=1 prec=2/2=1, recall=2/2=1
@ -611,7 +632,9 @@ public class ROCTest extends BaseNd4jTest {
@Test
public void testRocAucExact() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocAucExact(Nd4jBackend backend) {
//Check the implementation vs. Scikitlearn
/*
@ -774,7 +797,9 @@ public class ROCTest extends BaseNd4jTest {
@Test
public void rocExactEdgeCaseReallocation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
//Set reallocation block size to say 20, but then evaluate a 100-length array
@ -786,7 +811,9 @@ public class ROCTest extends BaseNd4jTest {
@Test
public void testPrecisionRecallCurveGetPointMethods() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
double[] threshold = new double[101];
double[] precision = threshold;
double[] recall = new double[101];
@ -822,7 +849,9 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void testPrecisionRecallCurveConfusion() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
//Sanity check: values calculated from the confusion matrix should match the PR curve values
for (boolean removeRedundantPts : new boolean[] {true, false}) {
@ -861,6 +890,8 @@ public class ROCTest extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocMerge(){
Nd4j.getRandom().setSeed(12345);
@ -905,6 +936,8 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocMultiMerge(){
Nd4j.getRandom().setSeed(12345);
@ -954,6 +987,8 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBinaryMerge(){
Nd4j.getRandom().setSeed(12345);
@ -999,6 +1034,8 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentationBinary(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345);
@ -1089,6 +1126,8 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345);

View File

@ -21,9 +21,11 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -40,11 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
public class RegressionEvalTest extends BaseNd4jTest {
public class RegressionEvalTest extends BaseNd4jTestWithBackends {
public RegressionEvalTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -52,7 +51,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test()
public void testEvalParameters() {
public void testEvalParameters(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
int specCols = 5;
INDArray labels = Nd4j.ones(3);
@ -65,7 +64,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testPerfectPredictions() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPerfectPredictions(Nd4jBackend backend) {
int nCols = 5;
int nTestArrays = 100;
@ -92,7 +93,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testKnownValues() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testKnownValues(Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
RegressionEvaluation first = null;
@ -148,7 +151,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
@Test
public void testRegressionEvaluationMerging() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEvaluationMerging(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int nRows = 20;
@ -189,7 +194,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testRegressionEvalPerOutputMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEvalPerOutputMasking(Nd4jBackend backend) {
INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}});
@ -216,6 +223,8 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEvalTimeSeriesSplit(){
INDArray out1 = Nd4j.rand(new int[]{3, 5, 20});
@ -238,7 +247,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testRegressionEval3d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -270,7 +281,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testRegressionEval4d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -302,7 +315,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testRegressionEval3dMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -361,7 +376,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testRegressionEval4dMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -22,10 +22,12 @@ package org.nd4j.evaluation;
import org.apache.commons.io.FileUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.io.ClassPathResource;
@ -34,11 +36,8 @@ import java.nio.charset.StandardCharsets;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestLegacyJsonLoading extends BaseNd4jTest {
public class TestLegacyJsonLoading extends BaseNd4jTestWithBackends {
public TestLegacyJsonLoading(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -46,7 +45,9 @@ public class TestLegacyJsonLoading extends BaseNd4jTest {
}
@Test
public void testEvalLegacyFormat() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalLegacyFormat(Nd4jBackend backend) throws Exception {
File f = new ClassPathResource("regression_testing/eval_100b/evaluation.json").getFile();
String s = FileUtils.readFileToString(f, StandardCharsets.UTF_8);

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -38,17 +39,14 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
@Slf4j
@RunWith(Parameterized.class)
public class AveragingTests extends BaseNd4jTest {
public class AveragingTests extends BaseNd4jTestWithBackends {
private final int THREADS = 16;
private final int LENGTH = 51200 * 4;
DataType initialType;
DataType initialType = Nd4j.dataType();
public AveragingTests(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@BeforeEach
public void setUp() {
@ -63,7 +61,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test
public void testSingleDeviceAveraging1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSingleDeviceAveraging1(Nd4jBackend backend) {
INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0);
INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0);
INDArray array3 = Nd4j.valueArrayOf(LENGTH, 3.0);
@ -110,7 +110,9 @@ public class AveragingTests extends BaseNd4jTest {
}
@Test
public void testSingleDeviceAveraging2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSingleDeviceAveraging2(Nd4jBackend backend) {
INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH);
List<INDArray> arrays = new ArrayList<>();
for (int i = 0; i < THREADS; i++)
@ -127,7 +129,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test
public void testAccumulation1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAccumulation1(Nd4jBackend backend) {
INDArray array1 = Nd4j.create(100).assign(1.0);
INDArray array2 = Nd4j.create(100).assign(2.0);
INDArray array3 = Nd4j.create(100).assign(3.0);
@ -140,7 +144,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test
public void testAccumulation2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAccumulation2(Nd4jBackend backend) {
INDArray array1 = Nd4j.create(100).assign(1.0);
INDArray array2 = Nd4j.create(100).assign(2.0);
INDArray array3 = Nd4j.create(100).assign(3.0);
@ -155,7 +161,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test
public void testAccumulation3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAccumulation3(Nd4jBackend backend) {
// we want to ensure that cuda backend is able to launch this op on cpu
Nd4j.getAffinityManager().allowCrossDeviceAccess(false);

View File

@ -23,8 +23,9 @@ package org.nd4j.linalg;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -34,15 +35,14 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
@Slf4j
public class DataTypeTest extends BaseNd4jTest {
public DataTypeTest(Nd4jBackend backend) {
super(backend);
}
public class DataTypeTest extends BaseNd4jTestWithBackends {
@Test
public void testDataTypes() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDataTypes(Nd4jBackend backend) throws Exception {
for (val type : DataType.values()) {
if (DataType.UTF8.equals(type) || DataType.UNKNOWN.equals(type) || DataType.COMPRESSED.equals(type))
continue;

View File

@ -21,20 +21,17 @@
package org.nd4j.linalg;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.fail;
@RunWith(Parameterized.class)
public class InputValidationTests extends BaseNd4jTest {
public InputValidationTests(Nd4jBackend backend) {
super(backend);
}
public class InputValidationTests extends BaseNd4jTestWithBackends {
@Override
public char ordering() {
@ -45,7 +42,9 @@ public class InputValidationTests extends BaseNd4jTest {
///////////////////// Broadcast Tests ///////////////////////
@Test
public void testInvalidColVectorOp1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidColVectorOp1(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10);
INDArray col = Nd4j.create(5, 1);
try {
@ -57,7 +56,9 @@ public class InputValidationTests extends BaseNd4jTest {
}
@Test
public void testInvalidColVectorOp2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidColVectorOp2(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10);
INDArray col = Nd4j.create(5, 1);
try {
@ -69,7 +70,9 @@ public class InputValidationTests extends BaseNd4jTest {
}
@Test
public void testInvalidRowVectorOp1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidRowVectorOp1(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10);
INDArray row = Nd4j.create(1, 5);
try {
@ -81,7 +84,9 @@ public class InputValidationTests extends BaseNd4jTest {
}
@Test
public void testInvalidRowVectorOp2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidRowVectorOp2(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10);
INDArray row = Nd4j.create(1, 5);
try {

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.RandomUtils;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
@ -47,14 +48,13 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j
@RunWith(Parameterized.class)
public class LoneTest extends BaseNd4jTest {
public LoneTest(Nd4jBackend backend) {
super(backend);
}
public class LoneTest extends BaseNd4jTestWithBackends {
@Test
public void testSoftmaxStability() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSoftmaxStability(Nd4jBackend backend) {
INDArray input = Nd4j.create(new double[]{-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose();
// System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo()));
INDArray output = Nd4j.create(DataType.DOUBLE, 10, 1);
@ -68,7 +68,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void testFlattenedView() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFlattenedView(Nd4jBackend backend) {
int rows = 8;
int cols = 8;
int dim2 = 4;
@ -104,7 +106,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void testIndexingColVec() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexingColVec(Nd4jBackend backend) {
int elements = 5;
INDArray rowVector = Nd4j.linspace(1, elements, elements).reshape(1, elements);
INDArray colVector = rowVector.transpose();
@ -123,7 +127,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void concatScalarVectorIssue() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void concatScalarVectorIssue(Nd4jBackend backend) {
//A bug was found when the first array that concat sees is a scalar and the rest vectors + scalars
INDArray arr1 = Nd4j.create(1, 1);
INDArray arr2 = Nd4j.create(1, 8);
@ -133,7 +139,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void reshapeTensorMmul() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void reshapeTensorMmul(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 2, 12).reshape(2, 3, 2);
INDArray b = Nd4j.linspace(3, 4, 4).reshape(2, 2);
int[][] axes = new int[2][];
@ -145,7 +153,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void maskWhenMerge() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void maskWhenMerge(Nd4jBackend backend) {
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
List<DataSet> dataSetList = new ArrayList<DataSet>();
@ -160,7 +170,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void testRelu() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRelu(Nd4jBackend backend) {
INDArray aA = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
INDArray b = Nd4j.getExecutioner().exec(new Tanh(aA));
@ -172,7 +184,7 @@ public class LoneTest extends BaseNd4jTest {
@Test
//broken at a threshold
public void testArgMax() {
public void testArgMax(Nd4jBackend backend) {
int max = 63;
INDArray A = Nd4j.linspace(1, max, max).reshape(1, max);
int currentArgMax = Nd4j.argMax(A).getInt(0);
@ -186,7 +198,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void testRPF() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRPF(Nd4jBackend backend) {
val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3);
log.info("--------");
@ -199,7 +213,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void testConcat3D_Vstack_C() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat3D_Vstack_C(Nd4jBackend backend) {
val shape = new long[]{1, 1000, 20};
List<INDArray> cArrays = new ArrayList<>();
@ -229,7 +245,9 @@ public class LoneTest extends BaseNd4jTest {
@Test
public void testGetRow1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRow1(Nd4jBackend backend) {
INDArray array = Nd4j.create(10000, 10000);
//Thread.sleep(10000);
@ -256,7 +274,7 @@ public class LoneTest extends BaseNd4jTest {
}
@Test()
public void checkIllegalElementOps() {
public void checkIllegalElementOps(Nd4jBackend backend) {
assertThrows(Exception.class,() -> {
INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5);
INDArray B = A.dup().reshape(2, 2, 5);
@ -268,7 +286,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void checkSliceofSlice() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void checkSliceofSlice(Nd4jBackend backend) {
/*
Issue 1: Slice of slice with c order and f order views are not equal
@ -308,7 +328,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void checkWithReshape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void checkWithReshape(Nd4jBackend backend) {
INDArray arr = Nd4j.create(1, 3);
INDArray reshaped = arr.reshape('f', 3, 1);
for (int i=0;i<reshaped.length();i++) {

View File

@ -21,6 +21,8 @@
package org.nd4j.linalg;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -28,11 +30,8 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class MmulBug extends BaseNd4jTest {
public class MmulBug extends BaseNd4jTestWithBackends {
public MmulBug(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -40,7 +39,9 @@ public class MmulBug extends BaseNd4jTest {
}
@Test
public void simpleTest() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void simpleTest(Nd4jBackend backend) {
INDArray m1 = Nd4j.create(new double[][] {{1.0}, {2.0}, {3.0}, {4.0}});
m1 = m1.reshape(2, 2);

View File

@ -25,8 +25,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -58,19 +59,14 @@ import static org.junit.jupiter.api.Assertions.*;
*
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
@Slf4j
public class NDArrayTestsFortran extends BaseNd4jTest {
public NDArrayTestsFortran(Nd4jBackend backend) {
super(backend);
}
public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
public void testScalarOps() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarOps(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
assertEquals(27d, n.length(), 1e-1);
n.addi(Nd4j.scalar(1d));
@ -88,7 +84,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testColumnMmul() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testColumnMmul(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 10, 18, DataType.FLOAT).data();
INDArray x2 = Nd4j.create(data, new long[] {2, 3, 3});
data = Nd4j.linspace(1, 12, 9, DataType.FLOAT).data();
@ -119,7 +117,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testRowVectorGemm() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRowVectorGemm(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1).castTo(DataType.DOUBLE);
INDArray other = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4).castTo(DataType.DOUBLE);
INDArray result = linspace.mmul(other);
@ -130,13 +130,17 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testRepmat() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRepmat(Nd4jBackend backend) {
INDArray rowVector = Nd4j.create(1, 4);
INDArray repmat = rowVector.repmat(4, 4);
assertTrue(Arrays.equals(new long[] {4, 16}, repmat.shape()));
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReadWrite() throws Exception {
INDArray write = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
ByteArrayOutputStream bos = new ByteArrayOutputStream();
@ -152,6 +156,8 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReadWriteDouble() throws Exception {
INDArray write = Nd4j.linspace(1, 4, 4, DataType.FLOAT);
ByteArrayOutputStream bos = new ByteArrayOutputStream();
@ -168,18 +174,17 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMultiThreading() throws Exception {
ExecutorService ex = ExecutorServiceProvider.getExecutorService();
List<Future<?>> list = new ArrayList<>(100);
for (int i = 0; i < 100; i++) {
Future<?> future = ex.submit(new Runnable() {
@Override
public void run() {
Future<?> future = ex.submit(() -> {
INDArray dot = Nd4j.linspace(1, 8, 8, DataType.DOUBLE);
// System.out.println(Transforms.sigmoid(dot));
Transforms.sigmoid(dot);
}
});
list.add(future);
}
@ -191,7 +196,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testBroadcastingGenerated() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastingGenerated(Nd4jBackend backend) {
int[][] broadcastShape = NDArrayCreationUtil.getRandomBroadCastShape(7, 6, 10);
List<List<Pair<INDArray, String>>> broadCastList = new ArrayList<>(broadcastShape.length);
for (int[] shape : broadcastShape) {
@ -216,7 +223,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testBroadCasting() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadCasting(Nd4jBackend backend) {
INDArray first = Nd4j.arange(0, 3).reshape(3, 1).castTo(DataType.DOUBLE);
INDArray ret = first.broadcast(3, 4);
INDArray testRet = Nd4j.create(new double[][] {{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}});
@ -229,14 +238,18 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testOneTensor() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneTensor(Nd4jBackend backend) {
INDArray arr = Nd4j.ones(1, 1, 1, 1, 1, 1, 1);
INDArray matrixToBroadcast = Nd4j.ones(1, 1);
assertEquals(matrixToBroadcast.broadcast(arr.shape()), arr);
}
@Test
public void testSortWithIndicesDescending() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSortWithIndicesDescending(Nd4jBackend backend) {
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
//indices,data
INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, false);
@ -247,7 +260,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testSortDeadlock() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSortDeadlock(Nd4jBackend backend) {
val toSort = Nd4j.linspace(DataType.DOUBLE, 1, 32*768, 1).reshape(32, 768);
val sorted = Nd4j.sort(toSort.dup(), 1, false);
@ -255,7 +270,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testSortWithIndices() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSortWithIndices(Nd4jBackend backend) {
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
//indices,data
INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, true);
@ -266,14 +283,18 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testNd4jSortScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNd4jSortScalar(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(1, -1);
INDArray sorted = Nd4j.sort(linspace, 1, false);
// System.out.println(sorted);
}
@Test
public void testSwapAxesFortranOrder() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSwapAxesFortranOrder(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}).castTo(DataType.DOUBLE);
for (int i = 0; i < n.slices(); i++) {
INDArray nSlice = n.slice(i);
@ -292,7 +313,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testDimShuffle() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDimShuffle(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray twoOneTwo = n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false});
assertTrue(Arrays.equals(new long[] {2, 1, 2}, twoOneTwo.shape()));
@ -303,7 +326,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testGetVsGetScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetVsGetScalar(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
float element = a.getFloat(0, 1);
double element2 = a.getDouble(0, 1);
@ -316,7 +341,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testDivide() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDivide(Nd4jBackend backend) {
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
INDArray div = two.div(two);
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage());
@ -330,7 +357,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testSigmoid() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSigmoid(Nd4jBackend backend) {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
INDArray sigmoid = Transforms.sigmoid(n, false);
@ -339,7 +368,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testNeg() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNeg(Nd4jBackend backend) {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
INDArray neg = Transforms.neg(n);
@ -349,7 +380,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testCosineSim() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCosineSim(Nd4jBackend backend) {
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
double sim = Transforms.cosineSim(vec1, vec2);
@ -364,7 +397,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testExp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExp(Nd4jBackend backend) {
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f});
INDArray exped = Transforms.exp(n);
@ -374,7 +409,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalar(Nd4jBackend backend) {
INDArray a = Nd4j.scalar(1.0f);
assertEquals(true, a.isScalar());
@ -386,7 +423,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testWrap() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testWrap(Nd4jBackend backend) {
int[] shape = {2, 4};
INDArray d = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(shape[0], shape[1]);
INDArray n = d;
@ -411,7 +450,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testGetRowFortran() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRowFortran(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.FLOAT).data(), new long[] {2, 2});
INDArray column = Nd4j.create(new float[] {1, 3});
INDArray column2 = Nd4j.create(new float[] {2, 4});
@ -424,7 +465,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testGetColumnFortran() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumnFortran(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2});
INDArray column = Nd4j.create(new double[] {1, 2});
INDArray column2 = Nd4j.create(new double[] {3, 4});
@ -438,7 +481,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testGetColumns() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumns(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3).castTo(DataType.DOUBLE);
// log.info("Original: {}", matrix);
INDArray matrixGet = matrix.getColumns(1, 2);
@ -452,7 +497,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testVectorInit() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorInit(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data();
INDArray arr = Nd4j.create(data, new long[] {1, 4});
assertEquals(true, arr.isRowVector());
@ -465,7 +512,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testAssignOffset() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssignOffset(Nd4jBackend backend) {
INDArray arr = Nd4j.ones(5, 5);
INDArray row = arr.slice(1);
row.assign(1);
@ -473,7 +522,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testColumns() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testColumns(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE);
INDArray column = Nd4j.create(new double[] {1, 2, 3});
arr.putColumn(0, column);
@ -511,7 +562,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testPutRow() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRow(Nd4jBackend backend) {
INDArray d = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray n = d.dup();
@ -570,7 +623,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testInplaceTranspose() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInplaceTranspose(Nd4jBackend backend) {
INDArray test = Nd4j.rand(3, 4);
INDArray orig = test.dup();
INDArray transposei = test.transposei();
@ -585,7 +640,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testMmulF() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulF(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data();
INDArray n = Nd4j.create(data, new long[] {1, 10});
@ -603,7 +660,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testRowsColumns() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRowsColumns(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).data();
INDArray rows = Nd4j.create(data, new long[] {2, 3});
assertEquals(2, rows.rows());
@ -619,7 +678,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testTranspose() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTranspose(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.ones(100).castTo(DataType.DOUBLE).data(), new long[] {5, 5, 4});
INDArray transpose = n.transpose();
assertEquals(n.length(), transpose.length());
@ -647,7 +708,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testAddMatrix() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddMatrix(Nd4jBackend backend) {
INDArray five = Nd4j.ones(5);
five.addi(five.dup());
INDArray twos = Nd4j.valueArrayOf(5, 2);
@ -658,7 +721,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testMMul() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMMul(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
@ -669,7 +734,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testPutSlice() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutSlice(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3);
INDArray newSlice = Nd4j.create(DataType.DOUBLE, 3, 3);
Nd4j.exec(new PrintVariable(newSlice));
@ -680,7 +747,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testRowVectorMultipleIndices() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
linear.putScalar(new long[] {0, 1}, 1);
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage());
@ -689,7 +758,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testDim1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDim1(Nd4jBackend backend) {
INDArray sum = Nd4j.linspace(1, 2, 2, DataType.DOUBLE).reshape(2, 1);
INDArray same = sum.dup();
assertEquals(same.sum(1), sum.reshape(2));
@ -697,7 +768,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testEps() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEps(Nd4jBackend backend) {
val ones = Nd4j.ones(5);
val res = Nd4j.createUninitialized(DataType.BOOL, 5);
assertTrue(Nd4j.getExecutioner().exec(new Eps(ones, ones, res)).all());
@ -705,7 +778,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testLogDouble() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogDouble(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).castTo(DataType.DOUBLE);
INDArray log = Transforms.log(linspace);
INDArray assertion = Nd4j.create(new double[] {0, 0.6931471805599453, 1.0986122886681098, 1.3862943611198906, 1.6094379124341005, 1.791759469228055});
@ -713,28 +788,36 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testVectorSum() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorSum(Nd4jBackend backend) {
INDArray lin = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
}
@Test
public void testVectorSum2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorSum2(Nd4jBackend backend) {
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
}
@Test
public void testVectorSum3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorSum3(Nd4jBackend backend) {
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray lin2 = Nd4j.create(new double[] {1, 2, 3, 4});
assertEquals(lin, lin2);
}
@Test
public void testSmallSum() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSmallSum(Nd4jBackend backend) {
INDArray base = Nd4j.create(new double[] {5.843333333333335, 3.0540000000000007});
base.addi(1e-12);
INDArray assertion = Nd4j.create(new double[] {5.84333433, 3.054001});
@ -745,7 +828,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testPermute() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4});
INDArray transpose = n.transpose();
INDArray permute = n.permute(1, 0);
@ -774,7 +859,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testAppendBias() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAppendBias(Nd4jBackend backend) {
INDArray rand = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(1, -1).transpose();
INDArray test = Nd4j.appendBias(rand);
INDArray assertion = Nd4j.toFlattened(rand, Nd4j.scalar(DataType.DOUBLE, 1.0)).reshape(-1, 1);
@ -782,7 +869,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testRand() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRand(Nd4jBackend backend) {
INDArray rand = Nd4j.randn(5, 5);
Nd4j.getDistributions().createUniform(0.4, 4).sample(5);
Nd4j.getDistributions().createNormal(1, 5).sample(10);
@ -794,7 +883,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testIdentity() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIdentity(Nd4jBackend backend) {
INDArray eye = Nd4j.eye(5);
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
eye = Nd4j.eye(5);
@ -805,7 +896,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testColumnVectorOpsFortran() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testColumnVectorOpsFortran(Nd4jBackend backend) {
INDArray twoByTwo = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray toAdd = Nd4j.create(new float[] {1, 2}, new long[] {2, 1});
twoByTwo.addiColumnVector(toAdd);
@ -816,7 +909,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testRSubi() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRSubi(Nd4jBackend backend) {
INDArray n2 = Nd4j.ones(2);
INDArray n2Assertion = Nd4j.zeros(2);
INDArray nRsubi = n2.rsubi(1);
@ -826,7 +921,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testAssign() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssign(Nd4jBackend backend) {
INDArray vector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
vector.assign(1);
assertEquals(Nd4j.ones(5).castTo(DataType.DOUBLE), vector);
@ -843,7 +940,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testAddScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddScalar(Nd4jBackend backend) {
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
INDArray rdiv = div.add(1);
INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 5.0);
@ -851,7 +950,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testRdivScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRdivScalar(Nd4jBackend backend) {
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
INDArray rdiv = div.rdiv(1);
INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 0.25);
@ -859,7 +960,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testRDivi() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRDivi(Nd4jBackend backend) {
INDArray n2 = Nd4j.valueArrayOf(new long[] {1, 2}, 4.0);
INDArray n2Assertion = Nd4j.valueArrayOf(new long[] {1, 2}, 0.5);
INDArray nRsubi = n2.rdivi(2);
@ -869,7 +972,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testNumVectorsAlongDimension() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNumVectorsAlongDimension(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2);
assertEquals(12, arr.vectorsAlongDimension(2));
}
@ -877,7 +982,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testBroadCast() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadCast(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
INDArray broadCasted = n.broadcast(5, 4);
for (int i = 0; i < broadCasted.rows(); i++) {
@ -899,7 +1006,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testMatrix() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrix(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray brr = Nd4j.create(new double[] {5, 6}, new long[] {2});
INDArray row = arr.getRow(0);
@ -909,7 +1018,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testPutRowGetRowOrdering() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRowGetRowOrdering(Nd4jBackend backend) {
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray put = Nd4j.create(new double[] {5, 6});
row1.putRow(1, put);
@ -931,7 +1042,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testSumWithRow1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSumWithRow1(Nd4jBackend backend) {
//Works:
INDArray array2d = Nd4j.ones(1, 10);
array2d.sum(0); //OK
@ -962,7 +1075,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testSumWithRow2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSumWithRow2(Nd4jBackend backend) {
//All sums in this method execute without exceptions.
INDArray array3d = Nd4j.ones(2, 10, 10);
array3d.sum(0);
@ -985,7 +1100,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testPutRowFortran() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRowFortran(Nd4jBackend backend) {
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2).castTo(DataType.DOUBLE);
INDArray put = Nd4j.create(new double[] {5, 6});
row1.putRow(1, put);
@ -998,7 +1115,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testElementWiseOps() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testElementWiseOps(Nd4jBackend backend) {
INDArray n1 = Nd4j.scalar(1);
INDArray n2 = Nd4j.scalar(2);
INDArray nClone = n1.add(n2);
@ -1021,7 +1140,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testRollAxis() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRollAxis(Nd4jBackend backend) {
INDArray toRoll = Nd4j.ones(3, 4, 5, 6);
assertArrayEquals(new long[] {3, 6, 4, 5}, Nd4j.rollAxis(toRoll, 3, 1).shape());
val shape = Nd4j.rollAxis(toRoll, 3).shape();
@ -1030,7 +1151,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
@Disabled
public void testTensorDot() {
public void testTensorDot(Nd4jBackend backend) {
INDArray oneThroughSixty = Nd4j.arange(60).reshape('f', 3, 4, 5).castTo(DataType.DOUBLE);
INDArray oneThroughTwentyFour = Nd4j.arange(24).reshape('f', 4, 3, 2).castTo(DataType.DOUBLE);
INDArray result = Nd4j.tensorMmul(oneThroughSixty, oneThroughTwentyFour, new int[][] {{1, 0}, {0, 1}});
@ -1043,7 +1164,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testNegativeShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNegativeShape(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
INDArray reshaped = linspace.reshape(-1, 2);
assertArrayEquals(new long[] {2, 2}, reshaped.shape());
@ -1055,7 +1178,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testGetColumnGetRow() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumnGetRow(Nd4jBackend backend) {
INDArray row = Nd4j.ones(1, 5);
for (int i = 0; i < 5; i++) {
INDArray col = row.getColumn(i);
@ -1070,7 +1195,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testDupAndDupWithOrder() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDupAndDupWithOrder(Nd4jBackend backend) {
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
int count = 0;
for (Pair<INDArray, String> pair : testInputs) {
@ -1092,7 +1219,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testToOffsetZeroCopy() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToOffsetZeroCopy(Nd4jBackend backend) {
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
int cnt = 0;

View File

@ -23,8 +23,9 @@ package org.nd4j.linalg;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -42,18 +43,14 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class Nd4jTestsComparisonC extends BaseNd4jTest {
public class Nd4jTestsComparisonC extends BaseNd4jTestWithBackends {
private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonC.class);
public static final int SEED = 123;
DataType initialType;
DataType initialType = Nd4j.dataType();
public Nd4jTestsComparisonC(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@BeforeEach
@ -73,7 +70,9 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest {
@Test
public void testGemmWithOpsCommonsMath() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);

View File

@ -25,8 +25,9 @@ import org.apache.commons.math3.linear.RealMatrix;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -43,18 +44,14 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
@RunWith(Parameterized.class)
public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonFortran.class);
public static final int SEED = 123;
DataType initialType;
DataType initialType = Nd4j.dataType();
public Nd4jTestsComparisonFortran(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@BeforeEach
@ -75,7 +72,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
}
@Test
public void testCrash() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCrash(Nd4jBackend backend) {
INDArray array3d = Nd4j.ones(1, 10, 10);
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 0);
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 1);
@ -85,7 +84,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
}
@Test
public void testMmulWithOpsCommonsMath() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
@ -100,7 +101,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
}
@Test
public void testGemmWithOpsCommonsMath() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
@ -156,7 +159,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
}
@Test
public void testGemvApacheCommons() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemvApacheCommons(Nd4jBackend backend) {
int[] rowsArr = new int[] {4, 4, 4, 8, 8, 8};
int[] colsArr = new int[] {2, 1, 10, 2, 1, 10};
@ -211,7 +216,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
}
@Test
public void testAddSubtractWithOpsCommonsMath() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddSubtractWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
for (int i = 0; i < first.size(); i++) {
@ -229,7 +236,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
}
@Test
public void testMulDivOnCheckUtilMatrices() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMulDivOnCheckUtilMatrices(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
for (int i = 0; i < first.size(); i++) {

View File

@ -23,8 +23,9 @@ package org.nd4j.linalg;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -36,18 +37,15 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j
@RunWith(Parameterized.class)
public class Nd4jTestsF extends BaseNd4jTest {
DataType initialType;
public class Nd4jTestsF extends BaseNd4jTestWithBackends {
public Nd4jTestsF(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
DataType initialType = Nd4j.dataType();
@Test
public void testConcat3D_Vstack_F() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat3D_Vstack_F(Nd4jBackend backend) {
//Nd4j.getExecutioner().enableVerboseMode(true);
//Nd4j.getExecutioner().enableDebugMode(true);
@ -79,7 +77,9 @@ public class Nd4jTestsF extends BaseNd4jTest {
@Test
public void testSlice_1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlice_1(Nd4jBackend backend) {
val arr = Nd4j.linspace(1,4, 4, DataType.DOUBLE).reshape(2, 2, 1);
val exp0 = Nd4j.create(new double[]{1, 3}, new int[] {2, 1});
val exp1 = Nd4j.create(new double[]{2, 4}, new int[] {2, 1});

View File

@ -22,8 +22,9 @@ package org.nd4j.linalg;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -34,15 +35,13 @@ import java.util.*;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.*;
@RunWith(Parameterized.class)
public class ShufflesTests extends BaseNd4jTest {
public ShufflesTests(Nd4jBackend backend) {
super(backend);
}
public class ShufflesTests extends BaseNd4jTestWithBackends {
@Test
public void testSimpleShuffle1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimpleShuffle1(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(10, 10);
for (int x = 0; x < 10; x++) {
array.getRow(x).assign(x);
@ -64,7 +63,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testSimpleShuffle2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimpleShuffle2(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(10, 10);
for (int x = 0; x < 10; x++) {
array.getColumn(x).assign(x);
@ -79,7 +80,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testSimpleShuffle3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimpleShuffle3(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(11, 10);
for (int x = 0; x < 11; x++) {
array.getRow(x).assign(x);
@ -95,7 +98,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testSymmetricShuffle1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSymmetricShuffle1(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10);
INDArray labels = Nd4j.zeros(10, 3);
for (int x = 0; x < 10; x++) {
@ -133,7 +138,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testSymmetricShuffle2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSymmetricShuffle2(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10, 20);
INDArray labels = Nd4j.zeros(10, 10, 3);
@ -171,7 +178,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testSymmetricShuffle3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSymmetricShuffle3(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10, 20);
INDArray featuresMask = Nd4j.zeros(10, 20);
INDArray labels = Nd4j.zeros(10, 10, 3);
@ -236,7 +245,9 @@ public class ShufflesTests extends BaseNd4jTest {
* @throws Exception
*/
@Test
public void testHalfVectors1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testHalfVectors1(Nd4jBackend backend) {
int[] array1 = ArrayUtil.buildHalfVector(new Random(12), 20);
int[] array2 = ArrayUtil.buildHalfVector(new Random(75), 20);
@ -257,7 +268,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testInterleavedVector1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInterleavedVector1(Nd4jBackend backend) {
int[] array1 = ArrayUtil.buildInterleavedVector(new Random(12), 20);
int[] array2 = ArrayUtil.buildInterleavedVector(new Random(75), 20);
@ -278,7 +291,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testInterleavedVector3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInterleavedVector3(Nd4jBackend backend) {
for (int e = 0; e < 1000; e++) {
int length = e + 256; //RandomUtils.nextInt(121, 2073);
int[] array1 = ArrayUtil.buildInterleavedVector(new Random(System.currentTimeMillis()), length);

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.eigen.Eigen;
@ -35,16 +36,11 @@ import org.nd4j.common.util.ArrayUtil;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
@Slf4j
public class TestEigen extends BaseNd4jTest {
public class TestEigen extends BaseNd4jTestWithBackends {
protected DataType initialType;
public TestEigen(Nd4jBackend backend) {
super(backend);
initialType = Nd4j.dataType();
}
protected DataType initialType = Nd4j.dataType();
@BeforeEach
public void before() {
@ -59,7 +55,9 @@ public class TestEigen extends BaseNd4jTest {
// test of functions added by Luke Czapla
// Compares solution of A x = L x to solution to A x = L B x when it is simple
@Test
public void test2Syev() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void test2Syev(Nd4jBackend backend) {
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
Nd4j.setDefaultDataTypes(dt, dt);
@ -78,7 +76,9 @@ public class TestEigen extends BaseNd4jTest {
}
@Test
public void testSyev() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSyev(Nd4jBackend backend) {
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
//log.info("Datatype: {}", dt);
Nd4j.setDefaultDataTypes(dt, dt);

View File

@ -24,23 +24,23 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.util.ArrayUtil;
@RunWith(Parameterized.class)
@Slf4j
public class ToStringTest extends BaseNd4jTest {
public ToStringTest(Nd4jBackend backend) {
super(backend);
}
public class ToStringTest extends BaseNd4jTestWithBackends {
@Test
public void testToString() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToString(Nd4jBackend backend) throws Exception {
assertEquals("[ 1, 2, 3]",
Nd4j.createFromArray(1, 2, 3).toString());
@ -58,6 +58,8 @@ public class ToStringTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToStringScalars(){
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.activations;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.impl.ActivationCube;
import org.nd4j.linalg.activations.impl.ActivationELU;
import org.nd4j.linalg.activations.impl.ActivationGELU;
@ -55,12 +56,9 @@ import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class TestActivation extends BaseNd4jTest {
public TestActivation(Nd4jBackend backend) {
super(backend);
}
public class TestActivation extends BaseNd4jTestWithBackends {
@Override
public char ordering() {
@ -79,7 +77,9 @@ public class TestActivation extends BaseNd4jTest {
}
@Test
public void testRelu(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRelu(Nd4jBackend backend){
Double[] max = {null, 6.0, 2.5, 5.0};
Double[] threshold = {0.0, 0.0, 0.75, 0.2};
@ -131,7 +131,9 @@ public class TestActivation extends BaseNd4jTest {
}
@Test
public void testJson() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testJson(Nd4jBackend backend) throws Exception {
IActivation[] activations = new IActivation[] {new ActivationCube(), new ActivationELU(0.25),
new ActivationHardSigmoid(), new ActivationHardTanH(), new ActivationIdentity(),

View File

@ -20,21 +20,20 @@
package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.factory.Environment;
import org.nd4j.linalg.factory.Nd4j;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertFalse;
public class TestBackend extends BaseNd4jTest {
public class TestBackend extends BaseNd4jTestWithBackends {
public TestBackend(Nd4jBackend backend) {
super(backend);
}
@Test
public void TestBuildInfo(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBuildInfo(Nd4jBackend backend){
System.out.println("Backend build info: " + backend.buildInfo());
}
}

View File

@ -20,18 +20,17 @@
package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Environment;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertFalse;
public class TestEnvironment extends BaseNd4jTest {
public class TestEnvironment extends BaseNd4jTestWithBackends {
public TestEnvironment(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -39,7 +38,9 @@ public class TestEnvironment extends BaseNd4jTest {
}
@Test
public void testEnvironment(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEnvironment(Nd4jBackend backend){
Environment e = Nd4j.getEnvironment();
System.out.println("BLAS version: " + e.blasMajorVersion() + "." + e.blasMinorVersion() + "." + e.blasPatchVersion());
System.out.println("CPU: " + e.isCPU());

View File

@ -26,7 +26,9 @@ import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -40,16 +42,12 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class TestNDArrayCreation extends BaseNd4jTest {
public TestNDArrayCreation(Nd4jBackend backend) {
super(backend);
}
public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
@Test
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public void testBufferCreation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBufferCreation(Nd4jBackend backend) {
DataBuffer dataBuffer = Nd4j.createBuffer(new float[] {1, 2});
Pointer pointer = dataBuffer.pointer();
FloatPointer floatPointer = new FloatPointer(pointer);
@ -69,6 +67,8 @@ public class TestNDArrayCreation extends BaseNd4jTest {
@Test
@Disabled
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCreateNpy() throws Exception {
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile());
assertEquals(2, arrCreate.size(0));
@ -82,7 +82,9 @@ public class TestNDArrayCreation extends BaseNd4jTest {
@Test
@Disabled
public void testCreateNpz() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCreateNpz(Nd4jBackend backend) throws Exception {
Map<String, INDArray> map = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile());
assertEquals(true, map.containsKey("x"));
assertEquals(true, map.containsKey("y"));
@ -100,8 +102,7 @@ public class TestNDArrayCreation extends BaseNd4jTest {
}
@Test
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public void testCreateNpy3() throws Exception {
public void testCreateNpy3(Nd4jBackend backend) throws Exception {
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
assertEquals(8, arrCreate.length());
assertEquals(3, arrCreate.rank());
@ -113,7 +114,7 @@ public class TestNDArrayCreation extends BaseNd4jTest {
@Test
@Disabled // this is endless test
public void testEndlessAllocation() {
public void testEndlessAllocation(Nd4jBackend backend) {
Nd4j.getEnvironment().setMaxSpecialMemory(1);
while (true) {
val arr = Nd4j.createUninitialized(DataType.FLOAT, 100000000);

View File

@ -21,24 +21,23 @@
package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
public class TestNDArrayCreationUtil extends BaseNd4jTest {
public class TestNDArrayCreationUtil extends BaseNd4jTestWithBackends {
public TestNDArrayCreationUtil(Nd4jBackend backend) {
super(backend);
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShapes() {
long[] shape2d = {2, 3};

View File

@ -21,20 +21,21 @@
package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
public class TestNamespaces extends BaseNd4jTest {
public class TestNamespaces extends BaseNd4jTestWithBackends {
public TestNamespaces(Nd4jBackend backend) {
super(backend);
}
@Test
public void testBitwiseSimple(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBitwiseSimple(Nd4jBackend backend){
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
INDArray y = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
@ -50,7 +51,9 @@ public class TestNamespaces extends BaseNd4jTest {
}
@Test
public void testMathSimple(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMathSimple(Nd4jBackend backend) {
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1);
INDArray abs = Nd4j.math.abs(x);
// System.out.println(x);
@ -65,7 +68,9 @@ public class TestNamespaces extends BaseNd4jTest {
}
@Test
public void testRandomSimple(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomSimple(Nd4jBackend backend){
INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10);
// System.out.println(normal);
INDArray uniform = Nd4j.random.uniform(0, 1, DataType.FLOAT, 10);
@ -73,7 +78,9 @@ public class TestNamespaces extends BaseNd4jTest {
}
@Test
public void testNeuralNetworkSimple(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNeuralNetworkSimple(Nd4jBackend backend){
INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10));
// System.out.println(out);
INDArray out2 = Nd4j.nn.softmax(Nd4j.random.normal(0, 1, DataType.FLOAT, 4, 5), 1);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.blas;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -31,15 +32,14 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class LapackTest extends BaseNd4jTest {
public LapackTest(Nd4jBackend backend) {
super(backend);
}
public class LapackTest extends BaseNd4jTestWithBackends {
@Test
public void testQRSquare() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testQRSquare(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9});
A = A.reshape('c', 3, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape());
@ -57,7 +57,9 @@ public class LapackTest extends BaseNd4jTest {
}
@Test
public void testQRRect() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testQRRect(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
A = A.reshape('f', 4, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape());
@ -75,7 +77,9 @@ public class LapackTest extends BaseNd4jTest {
}
@Test
public void testCholeskyL() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCholeskyL(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {2, -1, 1, -1, 2, -1, 1, -1, 2,});
A = A.reshape('c', 3, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape());
@ -92,7 +96,9 @@ public class LapackTest extends BaseNd4jTest {
}
@Test
public void testCholeskyU() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCholeskyU(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,});
A = A.reshape('f', 3, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape());

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.api.blas;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -35,14 +36,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class Level1Test extends BaseNd4jTest {
public Level1Test(Nd4jBackend backend) {
super(backend);
}
public class Level1Test extends BaseNd4jTestWithBackends {
@Test
public void testDot() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDot(Nd4jBackend backend) {
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4});
assertEquals(30, Nd4j.getBlasWrapper().dot(vec1, vec2), 1e-1);
@ -55,7 +55,9 @@ public class Level1Test extends BaseNd4jTest {
}
@Test
public void testAxpy() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAxpy(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray row = matrix.getRow(1);
Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row);
@ -64,7 +66,9 @@ public class Level1Test extends BaseNd4jTest {
}
@Test
public void testAxpy2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAxpy2(Nd4jBackend backend) {
val rowX = Nd4j.create(new double[]{1, 2, 3, 4});
val rowY = Nd4j.create(new double[]{1, 2, 3, 4});
val exp = Nd4j.create(new double[]{3, 6, 9, 12});

View File

@ -21,23 +21,23 @@
package org.nd4j.linalg.api.blas;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class Level2Test extends BaseNd4jTest {
public Level2Test(Nd4jBackend backend) {
super(backend);
}
public class Level2Test extends BaseNd4jTestWithBackends {
@Test
public void testGemv1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv1(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -51,7 +51,9 @@ public class Level2Test extends BaseNd4jTest {
}
@Test
public void testGemv2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv2(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -65,7 +67,9 @@ public class Level2Test extends BaseNd4jTest {
}
@Test
public void testGemv3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv3(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -79,7 +83,9 @@ public class Level2Test extends BaseNd4jTest {
}
@Test
public void testGemv4() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv4(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -93,7 +99,9 @@ public class Level2Test extends BaseNd4jTest {
}
@Test
public void testGemv5() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv5(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -109,7 +117,9 @@ public class Level2Test extends BaseNd4jTest {
}
@Test
public void testGemv6() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv6(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -125,7 +135,9 @@ public class Level2Test extends BaseNd4jTest {
}
@Test
public void testGemv7() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv7(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);

View File

@ -21,23 +21,23 @@
package org.nd4j.linalg.api.blas;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class Level3Test extends BaseNd4jTest {
public Level3Test(Nd4jBackend backend) {
super(backend);
}
public class Level3Test extends BaseNd4jTestWithBackends {
@Test
public void testGemm1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm1(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape(1, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -47,7 +47,9 @@ public class Level3Test extends BaseNd4jTest {
}
@Test
public void testGemm2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm2(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape('f', 1, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -57,7 +59,9 @@ public class Level3Test extends BaseNd4jTest {
}
@Test
public void testGemm3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm3(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
@ -75,7 +79,9 @@ public class Level3Test extends BaseNd4jTest {
}
@Test
public void testGemm4() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm4(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);
@ -92,7 +98,9 @@ public class Level3Test extends BaseNd4jTest {
}
@Test
public void testGemm5() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm5(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
@ -106,7 +114,9 @@ public class Level3Test extends BaseNd4jTest {
}
@Test
public void testGemm6() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm6(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.blas.params;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -33,16 +34,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class ParamsTestsF extends BaseNd4jTest {
public ParamsTestsF(Nd4jBackend backend) {
super(backend);
}
public class ParamsTestsF extends BaseNd4jTestWithBackends {
@Test
public void testGemm() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm (Nd4jBackend backend) {
INDArray a = Nd4j.create(2, 2);
INDArray b = Nd4j.create(2, 3);
INDArray c = Nd4j.create(2, 3);

View File

@ -25,9 +25,10 @@ import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
@ -45,16 +46,15 @@ import java.nio.ByteOrder;
import static org.junit.jupiter.api.Assertions.*;
@Slf4j
@RunWith(Parameterized.class)
public class DataBufferTests extends BaseNd4jTest {
public DataBufferTests(Nd4jBackend backend) {
super(backend);
}
public class DataBufferTests extends BaseNd4jTestWithBackends {
@Test
@Disabled("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657")
public void testNoArgCreateBufferFromArray() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNoArgCreateBufferFromArray(Nd4jBackend backend) {
//Tests here:
//1. Create from JVM array
@ -280,7 +280,9 @@ public class DataBufferTests extends BaseNd4jTest {
@Test
public void testCreateTypedBuffer() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCreateTypedBuffer(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
@ -350,7 +352,9 @@ public class DataBufferTests extends BaseNd4jTest {
}
@Test
public void testAsBytes() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAsBytes(Nd4jBackend backend) {
INDArray orig = Nd4j.linspace(DataType.INT, 0, 10, 1);
for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16,
@ -405,6 +409,8 @@ public class DataBufferTests extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEnsureLocation(){
//https://github.com/eclipse/deeplearning4j/issues/8783
Nd4j.create(1);

View File

@ -23,9 +23,10 @@ package org.nd4j.linalg.api.buffer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
@ -33,13 +34,10 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertThrows;
@RunWith(Parameterized.class)
public class DataTypeValidationTests extends BaseNd4jTest {
DataType initialType;
public DataTypeValidationTests(Nd4jBackend backend) {
super(backend);
}
public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
@BeforeEach
public void setUp() {
@ -48,7 +46,7 @@ public class DataTypeValidationTests extends BaseNd4jTest {
}
@AfterEach
public void shutUp() {
public void reset() {
Nd4j.setDataType(initialType);
}
@ -73,7 +71,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
* Testing level1 blas
*/
@Test()
public void testBlasValidation1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBlasValidation1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
INDArray x = Nd4j.create(10);
@ -90,7 +90,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
* Testing level2 blas
*/
@Test()
public void testBlasValidation2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBlasValidation2(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> {
INDArray a = Nd4j.create(100, 10);
INDArray x = Nd4j.create(100);
@ -108,7 +110,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
* Testing level3 blas
*/
@Test()
public void testBlasValidation3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBlasValidation3(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
INDArray x = Nd4j.create(100, 100);

View File

@ -26,9 +26,10 @@ import org.bytedeco.javacpp.indexer.Indexer;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
@ -54,34 +55,31 @@ import static org.junit.jupiter.api.Assertions.*;
*
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public class DoubleDataBufferTest extends BaseNd4jTest {
public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
DataType initialType;
public DoubleDataBufferTest(Nd4jBackend backend) {
super(backend);
initialType = Nd4j.dataType();
}
DataType initialType = Nd4j.dataType();
@BeforeEach
public void before() {
public void before(Nd4jBackend backend) {
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
}
@AfterEach
public void after() {
public void after(Nd4jBackend backend) {
DataTypeUtil.setDTypeForContext(initialType);
}
@Test
public void testPointerCreation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPointerCreation(Nd4jBackend backend) {
DoublePointer floatPointer = new DoublePointer(1, 2, 3, 4);
Indexer indexer = DoubleIndexer.create(floatPointer);
DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.DOUBLE, 4, indexer);
@ -90,7 +88,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
public void testGetSet() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetSet(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
double[] d2 = d.asDouble();
@ -101,6 +101,8 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerialization2() throws Exception {
INDArray[] arr = new INDArray[] {Nd4j.ones(1, 10),
// Nd4j.ones(5,10).getRow(2)
@ -129,6 +131,8 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerialization(@TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
DataBuffer buf = Nd4j.createBuffer(5);
@ -151,7 +155,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test
public void testDup() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDup(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
DataBuffer d2 = d.dup();
@ -161,7 +167,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test
public void testPut() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPut(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
d.put(0, 0.0);
@ -172,7 +180,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test
public void testGetRange() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
double[] get = buffer.getDoublesAt(0, 3);
double[] data = new double[] {1, 2, 3};
@ -187,7 +197,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
public void testGetOffsetRange() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetOffsetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
double[] get = buffer.getDoublesAt(1, 3);
double[] data = new double[] {2, 3, 4};
@ -202,7 +214,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
public void testAssign() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssign(Nd4jBackend backend) {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
DataBuffer one = Nd4j.createBuffer(new double[] {1});
DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3});
@ -213,7 +227,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test
public void testOffset() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOffset(Nd4jBackend backend) {
DataBuffer create = Nd4j.createBuffer(new double[] {1, 2, 3, 4}, 2);
assertEquals(2, create.length());
assertEquals(0, create.offset());
@ -223,7 +239,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
public void testReallocation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity());
double[] old = buffer.asDouble();
@ -233,7 +251,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
public void testReallocationWorkspace() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
@ -250,6 +270,8 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddressPointer(){
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
return;

View File

@ -27,7 +27,9 @@ import org.bytedeco.javacpp.indexer.Indexer;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
@ -54,14 +56,9 @@ import static org.junit.jupiter.api.Assertions.*;
* @author Adam Gibson
*/
@Disabled("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public class FloatDataBufferTest extends BaseNd4jTest {
public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
DataType initialType;
public FloatDataBufferTest(Nd4jBackend backend) {
super(backend);
initialType = Nd4j.dataType();
}
DataType initialType = Nd4j.dataType();
@BeforeEach
public void before() {
@ -76,7 +73,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test
public void testPointerCreation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPointerCreation(Nd4jBackend backend) {
FloatPointer floatPointer = new FloatPointer(1, 2, 3, 4);
Indexer indexer = FloatIndexer.create(floatPointer);
DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.FLOAT, 4, indexer);
@ -85,7 +84,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testGetSet() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetSet(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
float[] d2 = d.asFloat();
@ -96,7 +97,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test
public void testSerialization(@TempDir Path tempDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerialization(@TempDir Path tempDir,Nd4jBackend backend) throws Exception {
File dir = tempDir.toFile();
DataBuffer buf = Nd4j.createBuffer(5);
String fileName = "buf.ser";
@ -117,7 +120,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testDup() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDup(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
DataBuffer d2 = d.dup();
@ -125,7 +130,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testToNio() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToNio(Nd4jBackend backend) {
DataBuffer buff = Nd4j.createTypedBuffer(new double[] {1, 2, 3, 4}, DataType.FLOAT);
assertEquals(4, buff.length());
if (buff.allocationMode() == DataBuffer.AllocationMode.HEAP)
@ -137,7 +144,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testPut() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPut(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
d.put(0, 0.0);
@ -148,7 +157,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test
public void testGetRange() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(0, 3);
float[] data = new float[] {1, 2, 3};
@ -164,7 +175,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test
public void testGetOffsetRange() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetOffsetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(1, 3);
float[] data = new float[] {2, 3, 4};
@ -181,7 +194,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test
public void testAsBytes() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAsBytes(Nd4jBackend backend) {
INDArray arr = Nd4j.create(5);
byte[] d = arr.data().asBytes();
assertEquals(4 * 5, d.length,getFailureMessage());
@ -191,7 +206,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testAssign() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssign(Nd4jBackend backend) {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
DataBuffer one = Nd4j.createBuffer(new double[] {1});
DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3});
@ -201,7 +218,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testReadWrite() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReadWrite(Nd4jBackend backend) throws Exception {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
ByteArrayOutputStream bos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(bos);
@ -215,7 +234,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testOffset() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOffset(Nd4jBackend backend) {
DataBuffer create = Nd4j.createBuffer(new float[] {1, 2, 3, 4}, 2);
assertEquals(2, create.length());
assertEquals(0, create.offset());
@ -225,7 +246,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testReallocation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity());
float[] old = buffer.asFloat();
@ -236,7 +259,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testReallocationWorkspace() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
@ -253,7 +278,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testAddressPointer(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddressPointer(Nd4jBackend backend){
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
return;
}

View File

@ -23,7 +23,9 @@ package org.nd4j.linalg.api.buffer;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
@ -37,13 +39,12 @@ import java.util.Arrays;
import static org.junit.jupiter.api.Assertions.*;
public class IntDataBufferTests extends BaseNd4jTest {
public class IntDataBufferTests extends BaseNd4jTestWithBackends {
public IntDataBufferTests(Nd4jBackend backend) {
super(backend);
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBasicSerde1() throws Exception {
@ -82,7 +83,9 @@ public class IntDataBufferTests extends BaseNd4jTest {
*/
@Test
public void testReallocation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity());
buffer.reallocate(6);
@ -94,7 +97,9 @@ public class IntDataBufferTests extends BaseNd4jTest {
}
@Test
public void testReallocationWorkspace() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -37,17 +38,15 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class IndexingTests extends BaseNd4jTest {
public class IndexingTests extends BaseNd4jTestWithBackends {
public IndexingTests(Nd4jBackend backend) {
super(backend);
}
@Test
public void testINDArrayIndexingEqualToRank() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testINDArrayIndexingEqualToRank(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{
{0,1,2},
@ -62,7 +61,9 @@ public class IndexingTests extends BaseNd4jTest {
@Test
public void testINDArrayIndexingLessThanRankSimple() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testINDArrayIndexingLessThanRankSimple(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{
{0},
@ -76,7 +77,9 @@ public class IndexingTests extends BaseNd4jTest {
@Test
public void testINDArrayIndexingLessThanRankFourDimension() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testINDArrayIndexingLessThanRankFourDimension(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{
{0},{1}
@ -89,7 +92,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testPutSimple() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutSimple(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2);
INDArray indexes = Nd4j.create(new double[][]{
{0},{1}
@ -101,7 +106,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testGetScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetScalar(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
INDArray d = arr.get(NDArrayIndex.point(1));
assertTrue(d.isScalar());
@ -110,14 +117,18 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testNewAxis() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis(Nd4jBackend backend) {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
// System.out.println(view);
}
@Test
public void testVectorIndexing() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorIndexing(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE);
int[] index = new int[] {5, 8, 9};
INDArray columnsTest = x.getColumns(index);
@ -129,7 +140,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testGetRowsColumnsMatrix() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRowsColumnsMatrix(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 6);
INDArray firstAndSecondColumnsAssertion = Nd4j.create(new double[][] {{1, 5}, {2, 6}, {3, 7}, {4, 8}});
@ -147,7 +160,9 @@ public class IndexingTests extends BaseNd4jTest {
@Test
public void testSlicing() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlicing(Nd4jBackend backend) {
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
INDArray slice1Assert = Nd4j.create(new double[] {2, 6, 10, 14});
INDArray slice1Test = arange.slice(1);
@ -155,7 +170,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testArangeMul() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testArangeMul(Nd4jBackend backend) {
INDArray arange = Nd4j.arange(1, 17).reshape('f', 4, 4).castTo(DataType.DOUBLE);
INDArrayIndex index = NDArrayIndex.interval(0, 2);
INDArray get = arange.get(index, index);
@ -167,7 +184,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testGetIndicesVector() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndicesVector(Nd4jBackend backend) {
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
INDArray test = Nd4j.create(new double[] {2, 3});
INDArray result = line.get(NDArrayIndex.point(0), NDArrayIndex.interval(1, 3));
@ -175,7 +194,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testGetIndicesVectorView() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndicesVectorView(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape('c',5, 5);
INDArray column = matrix.getColumn(0).reshape(1,5);
INDArray test = Nd4j.create(new double[] {6, 11});
@ -193,7 +214,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void test2dGetPoint(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void test2dGetPoint(Nd4jBackend backend){
INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4);
for( int i=0; i<3; i++ ){
INDArray exp = Nd4j.create(new double[]{i*4+1, i*4+2, i*4+3, i*4+4});

View File

@ -21,10 +21,11 @@
package org.nd4j.linalg.api.indexing;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -49,16 +50,15 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class IndexingTestsC extends BaseNd4jTest {
public class IndexingTestsC extends BaseNd4jTestWithBackends {
public IndexingTestsC(Nd4jBackend backend) {
super(backend);
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNegativeBounds() {
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
@ -71,6 +71,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis() {
INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2);
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all());
@ -80,6 +82,8 @@ public class IndexingTestsC extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void broadcastBug() {
INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2});
final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0));
@ -91,6 +95,8 @@ public class IndexingTestsC extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIntervalsIn3D() {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
@ -100,6 +106,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSmallInterval() {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
@ -109,6 +117,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllWithNewAxisAndInterval() {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3);
@ -118,6 +128,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllWithNewAxisInMiddle() {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3);
@ -127,6 +139,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllWithNewAxis() {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray get = arr.get(newAxis(), all(), point(1));
@ -137,6 +151,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexingWithMmul() {
INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
@ -148,6 +164,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPointPointInterval() {
INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3);
INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3));
@ -157,6 +175,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIntervalLowerBound() {
INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
@ -168,6 +188,8 @@ public class IndexingTestsC extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetPointRowVector() {
INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1);
@ -178,6 +200,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpecifiedIndexVector() {
INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2);
@ -195,6 +219,8 @@ public class IndexingTestsC extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRowIndexing() {
INDArray arr = Nd4j.ones(1, 10);
INDArray row = Nd4j.create(1, 10);
@ -205,6 +231,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorIndexing2() {
INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true));
INDArray assertion = Nd4j.create(new double[] {2, 4});
@ -220,6 +248,8 @@ public class IndexingTestsC extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOffsetsC() {
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
assertEquals(3, NDArrayIndex.offset(arr, 1, 1));
@ -236,6 +266,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexFor() {
long[] shape = {1, 2};
INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape);
@ -245,6 +277,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetScalar() {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
INDArray d = arr.get(point(1));
@ -254,6 +288,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorIndexing() {
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1);
INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
@ -262,6 +298,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNegativeIndices() {
INDArray test = Nd4j.create(10, 10, 10);
test.putScalar(new int[] {0, 0, -1}, 1.0);
@ -269,6 +307,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndices2d() {
INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2);
INDArray firstRow = twoByTwo.getRow(0);
@ -287,6 +327,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRow() {
Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5);
@ -304,6 +346,8 @@ public class IndexingTestsC extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRowEdgeCase() {
INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
INDArray get = rowVec.getRow(0); //Returning shape [1,1]
@ -313,6 +357,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumnEdgeCase() {
INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose();
INDArray get = colVec.getColumn(0); //Returning shape [1,1]
@ -322,6 +368,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcatColumns() {
INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE);
INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE);
@ -331,6 +379,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndicesVector() {
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
INDArray test = Nd4j.create(new double[] {2, 3});
@ -339,6 +389,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testArangeMul() {
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
INDArrayIndex index = interval(0, 2);
@ -350,6 +402,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexingThorough(){
long[] fullShape = {3,4,5,6,7};
@ -550,6 +604,8 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void debugging(){
long[] inShape = {3,4};
INDArrayIndex[] indexes = new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(1, 2, 4)};

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing.resolve;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -36,15 +37,14 @@ import static org.junit.jupiter.api.Assertions.*;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class NDArrayIndexResolveTests extends BaseNd4jTest {
public NDArrayIndexResolveTests(Nd4jBackend backend) {
super(backend);
}
public class NDArrayIndexResolveTests extends BaseNd4jTestWithBackends {
@Test
public void testResolvePoint() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testResolvePoint(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2);
INDArrayIndex[] test = NDArrayIndex.resolve(arr.shape(), NDArrayIndex.point(1));
INDArrayIndex[] assertion = {NDArrayIndex.point(1), NDArrayIndex.all()};
@ -59,6 +59,8 @@ public class NDArrayIndexResolveTests extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testResolvePointVector() {
INDArray arr = Nd4j.linspace(1, 4, 4);
INDArrayIndex[] getPoint = {NDArrayIndex.point(1)};

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing.shape;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.Indices;
@ -34,19 +35,15 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class IndexShapeTests extends BaseNd4jTest {
public IndexShapeTests(Nd4jBackend backend) {
super(backend);
}
public class IndexShapeTests extends BaseNd4jTestWithBackends {
private int[] shape = {1, 1, 2, 1, 3, 4, 5, 1};
@Test
public void testSinglePoint() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSinglePoint(Nd4jBackend backend) {
/*
Assumes all indexes are filled out.
Test simple general point case
@ -77,7 +74,9 @@ public class IndexShapeTests extends BaseNd4jTest {
}
@Test
public void testInterval() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInterval(Nd4jBackend backend) {
int[] basicAssertion = {1, 1, 1, 1, 3, 1, 2, 1};
INDArrayIndex[] basicTest = {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 1),
NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(1, 2),
@ -88,7 +87,9 @@ public class IndexShapeTests extends BaseNd4jTest {
@Test
public void testNewAxis() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis(Nd4jBackend backend) {
//normal prepend
int[] prependAssertion = {1, 1, 1, 1, 2, 1, 3, 4, 5, 1};
INDArrayIndex[] prependTest = {NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.all(),

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing.shape;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.Indices;
import org.nd4j.linalg.indexing.NDArrayIndex;
@ -33,25 +34,26 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class IndexShapeTests2d extends BaseNd4jTest {
public IndexShapeTests2d(Nd4jBackend backend) {
super(backend);
}
public class IndexShapeTests2d extends BaseNd4jTestWithBackends {
private long[] shape = {3, 2};
@Test
public void test2dCases() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void test2dCases(Nd4jBackend backend) {
assertArrayEquals(new long[] {1, 2}, Indices.shape(shape, NDArrayIndex.point(1)));
assertArrayEquals(new long[] {3, 1},
Indices.shape(shape, NDArrayIndex.all(), NDArrayIndex.point(1)));
}
@Test
public void testNewAxis2d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis2d(Nd4jBackend backend) {
assertArrayEquals(new long[] {1, 3, 2}, Indices.shape(shape,
NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all()));
assertArrayEquals(new long[] {3, 1, 2}, Indices.shape(shape,

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.api.iterator;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -33,15 +34,14 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class NDIndexIteratorTest extends BaseNd4jTest {
public NDIndexIteratorTest(Nd4jBackend backend) {
super(backend);
}
public class NDIndexIteratorTest extends BaseNd4jTestWithBackends {
@Test
public void testIterate() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIterate(Nd4jBackend backend) {
val shapeIter = new NdIndexIterator(2, 2);
val possibleSolutions = new long[][] {{0, 0}, {0, 1}, {1, 0}, {1, 1},};

View File

@ -28,9 +28,10 @@ import org.apache.commons.lang3.ArrayUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -45,16 +46,13 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j
@RunWith(Parameterized.class)
public class TestNdArrReadWriteTxt extends BaseNd4jTest {
public TestNdArrReadWriteTxt(Nd4jBackend backend) {
super(backend);
}
public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
@Test
public void compareAfterWrite(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
int [] ranksToCheck = new int[] {0,1,2,3,4};
for (int i = 0; i < ranksToCheck.length; i++) {
// log.info("Checking read write arrays with rank " + ranksToCheck[i]);
@ -84,7 +82,9 @@ public class TestNdArrReadWriteTxt extends BaseNd4jTest {
}
@Test
public void testNd4jReadWriteText(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNd4jReadWriteText(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile();
int count = 0;

View File

@ -25,9 +25,10 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.nio.file.Path;
@ -35,17 +36,14 @@ import java.nio.file.Path;
import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays;
@Slf4j
@RunWith(Parameterized.class)
public class TestNdArrReadWriteTxtC extends BaseNd4jTest {
public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends {
public TestNdArrReadWriteTxtC(Nd4jBackend backend) {
super(backend);
}
@Test
public void compareAfterWrite(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
int[] ranksToCheck = new int[]{0, 1, 2, 3, 4};
for (int i = 0; i < ranksToCheck.length; i++) {
log.info("Checking read write arrays with rank " + ranksToCheck[i]);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.ndarray;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
@ -32,16 +33,14 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class TestSerialization extends BaseNd4jTest {
public TestSerialization(Nd4jBackend backend) {
super(backend);
}
public class TestSerialization extends BaseNd4jTestWithBackends {
@Test
public void testSerializationFullArrayNd4jWriteRead() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
@ -71,7 +70,9 @@ public class TestSerialization extends BaseNd4jTest {
}
@Test
public void testSerializationFullArrayJava() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayJava(Nd4jBackend backend) throws Exception {
int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
@ -102,7 +103,9 @@ public class TestSerialization extends BaseNd4jTest {
}
@Test
public void testSerializationOnViewsNd4jWriteRead() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
@ -138,7 +141,9 @@ public class TestSerialization extends BaseNd4jTest {
}
@Test
public void testSerializationOnViewsJava() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsJava(Nd4jBackend backend) throws Exception {
int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);

View File

@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.factory.Nd4j;
@ -39,15 +40,11 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j
@RunWith(Parameterized.class)
public class TestSerializationDoubleToFloat extends BaseNd4jTest {
DataType initialType;
public class TestSerializationDoubleToFloat extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
public TestSerializationDoubleToFloat(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@AfterEach
public void after() {
@ -55,7 +52,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
}
@Test
public void testSerializationFullArrayNd4jWriteRead() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 4;
//WRITE OUT A DOUBLE ARRAY
@ -93,7 +92,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
}
@Test
public void testSerializationFullArrayJava() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayJava(Nd4jBackend backend) throws Exception {
int length = 100;
Nd4j.create(1);
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
@ -123,7 +124,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
}
@Test
public void testSerializationOnViewsNd4jWriteRead() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100;
Nd4j.create(1);
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
@ -153,7 +156,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
}
@Test
public void testSerializationOnViewsJava() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsJava(Nd4jBackend backend) throws Exception {
int length = 100;
Nd4j.create(1);
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.api.ndarray;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.factory.Nd4j;
@ -37,15 +38,11 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class TestSerializationFloatToDouble extends BaseNd4jTest {
DataType initialType;
public class TestSerializationFloatToDouble extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
public TestSerializationFloatToDouble(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@AfterEach
public void after() {
@ -53,7 +50,9 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
}
@Test
public void testSerializationFullArrayNd4jWriteRead() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100;
//WRITE OUT A FLOAT ARRAY
@ -86,6 +85,8 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayJava() throws Exception {
int length = 100;
Nd4j.create(1);
@ -117,6 +118,8 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsNd4jWriteRead() throws Exception {
int length = 100;
Nd4j.create(1);
@ -147,6 +150,8 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsJava() throws Exception {
int length = 100;
Nd4j.create(1);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.rng;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -33,14 +34,13 @@ import static org.junit.jupiter.api.Assertions.*;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class RngTests extends BaseNd4jTest {
public RngTests(Nd4jBackend backend) {
super(backend);
}
public class RngTests extends BaseNd4jTestWithBackends {
@Test
public void testRngConstitency() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRngConstitency(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(123);
INDArray arr = Nd4j.rand(1, 5);
Nd4j.getRandom().setSeed(123);
@ -49,7 +49,9 @@ public class RngTests extends BaseNd4jTest {
}
@Test
public void testRandomWithOrder() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomWithOrder(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
@ -105,7 +107,9 @@ public class RngTests extends BaseNd4jTest {
}
@Test
public void testRandomBinomial() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomBinomial(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
//silly tests. Just increasing the usage for randomBinomial to stop compiler warnings.
INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3);

View File

@ -23,9 +23,10 @@ package org.nd4j.linalg.api.string;
import lombok.extern.slf4j.Slf4j;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -35,22 +36,23 @@ import org.nd4j.linalg.string.NDArrayStrings;
* @author Adam Gibson
*/
@Slf4j
@RunWith(Parameterized.class)
public class TestFormatting extends BaseNd4jTest {
public TestFormatting(Nd4jBackend backend) {
super(backend);
}
public class TestFormatting extends BaseNd4jTestWithBackends {
@Test
public void testTwoByTwo() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTwoByTwo(Nd4jBackend backend) {
INDArray arr = Nd4j.create(2, 2, 2, 2);
System.out.println(new NDArrayStrings().format(arr));
}
@Test
public void testNd4jArrayString() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNd4jArrayString(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new float[]{1f, 20000000f, 40.838383f, 3f}, new int[]{2, 2});
@ -71,7 +73,9 @@ public class TestFormatting extends BaseNd4jTest {
}
@Test
public void testRange() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRange(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[][]{
{-1,0,1,0},
{-0.1, 0.1, -10, 10},

Some files were not shown because too many files have changed in this diff Show More