Migrate parameterized tests to junit 5

This commit is contained in:
agibsonccc 2021-03-16 22:08:35 +09:00
parent 82bdcc21d2
commit 3c6014271e
260 changed files with 10787 additions and 6270 deletions

View File

@ -37,8 +37,10 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -47,13 +49,14 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays; import java.util.Arrays;
import java.util.stream.Stream;
import static org.deeplearning4j.nn.conf.ConvolutionMode.Same; import static org.deeplearning4j.nn.conf.ConvolutionMode.Same;
import static org.deeplearning4j.nn.conf.ConvolutionMode.Truncate; import static org.deeplearning4j.nn.conf.ConvolutionMode.Truncate;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(Parameterized.class)
@DisplayName("Cnn Gradient Check Test") @DisplayName("Cnn Gradient Check Test")
class CNNGradientCheckTest extends BaseDL4JTest { class CNNGradientCheckTest extends BaseDL4JTest {
@ -71,15 +74,10 @@ class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
private CNN2DFormat format;
public CNNGradientCheckTest(CNN2DFormat format) {
this.format = format;
}
@Parameterized.Parameters(name = "{0}") public static Stream<Arguments> params() {
public static Object[] params() { return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of);
return CNN2DFormat.values();
} }
@Override @Override
@ -89,9 +87,11 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Gradient CNNMLN") @DisplayName("Test Gradient CNNMLN")
void testGradientCNNMLN() { @ParameterizedTest
@MethodSource("#params")
public void testGradientCNNMLN(CNN2DFormat format) {
if (// Only test NCHW due to flat input format... if (// Only test NCHW due to flat input format...
this.format != CNN2DFormat.NCHW) format != CNN2DFormat.NCHW)
return; return;
// Parameterized test, testing combinations of: // Parameterized test, testing combinations of:
// (a) activation function // (a) activation function
@ -146,9 +146,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Gradient CNNL 1 L 2 MLN") @DisplayName("Test Gradient CNNL 1 L 2 MLN")
void testGradientCNNL1L2MLN() { void testGradientCNNL1L2MLN(CNN2DFormat format) {
if (// Only test NCHW due to flat input format... if (// Only test NCHW due to flat input format...
this.format != CNN2DFormat.NCHW) format != CNN2DFormat.NCHW)
return; return;
// Parameterized test, testing combinations of: // Parameterized test, testing combinations of:
// (a) activation function // (a) activation function
@ -245,7 +245,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn With Space To Batch") @DisplayName("Test Cnn With Space To Batch")
void testCnnWithSpaceToBatch() { @ParameterizedTest
@MethodSource("#params")
public void testCnnWithSpaceToBatch(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int[] minibatchSizes = { 2, 4 }; int[] minibatchSizes = { 2, 4 };
@ -289,7 +291,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn With Upsampling") @DisplayName("Test Cnn With Upsampling")
void testCnnWithUpsampling() { @ParameterizedTest
@MethodSource("#params")
void testCnnWithUpsampling(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int[] minibatchSizes = { 1, 3 }; int[] minibatchSizes = { 1, 3 };
@ -323,7 +327,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn With Subsampling") @DisplayName("Test Cnn With Subsampling")
void testCnnWithSubsampling() { @ParameterizedTest
@MethodSource("#params")
void testCnnWithSubsampling(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int[] minibatchSizes = { 1, 3 }; int[] minibatchSizes = { 1, 3 };
@ -365,7 +371,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn With Subsampling V 2") @DisplayName("Test Cnn With Subsampling V 2")
void testCnnWithSubsamplingV2() { @ParameterizedTest
@MethodSource("#params")
void testCnnWithSubsamplingV2(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int[] minibatchSizes = { 1, 3 }; int[] minibatchSizes = { 1, 3 };
@ -403,7 +411,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Locally Connected 2 D") @DisplayName("Test Cnn Locally Connected 2 D")
void testCnnLocallyConnected2D() { @ParameterizedTest
@MethodSource("#params")
void testCnnLocallyConnected2D(CNN2DFormat format) {
int nOut = 3; int nOut = 3;
int width = 5; int width = 5;
int height = 5; int height = 5;
@ -433,7 +443,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Multi Layer") @DisplayName("Test Cnn Multi Layer")
void testCnnMultiLayer() { @ParameterizedTest
@MethodSource("#params")
void testCnnMultiLayer(CNN2DFormat format) {
int nOut = 2; int nOut = 2;
int[] minibatchSizes = { 1, 2, 5 }; int[] minibatchSizes = { 1, 2, 5 };
int width = 5; int width = 5;
@ -473,7 +485,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Same Padding Mode") @DisplayName("Test Cnn Same Padding Mode")
void testCnnSamePaddingMode() { @ParameterizedTest
@MethodSource("#params")
void testCnnSamePaddingMode(CNN2DFormat format) {
int nOut = 2; int nOut = 2;
int[] minibatchSizes = { 1, 3, 3, 2, 1, 2 }; int[] minibatchSizes = { 1, 3, 3, 2, 1, 2 };
// Same padding mode: insensitive to exact input size... // Same padding mode: insensitive to exact input size...
@ -507,7 +521,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Same Padding Mode Strided") @DisplayName("Test Cnn Same Padding Mode Strided")
void testCnnSamePaddingModeStrided() { @ParameterizedTest
@MethodSource("#params")
void testCnnSamePaddingModeStrided(CNN2DFormat format) {
int nOut = 2; int nOut = 2;
int[] minibatchSizes = { 1, 3 }; int[] minibatchSizes = { 1, 3 };
int width = 16; int width = 16;
@ -550,7 +566,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Zero Padding Layer") @DisplayName("Test Cnn Zero Padding Layer")
void testCnnZeroPaddingLayer() { @ParameterizedTest
@MethodSource("#params")
void testCnnZeroPaddingLayer(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int width = 6; int width = 6;
@ -596,7 +614,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Deconvolution 2 D") @DisplayName("Test Deconvolution 2 D")
void testDeconvolution2D() { @ParameterizedTest
@MethodSource("#params")
void testDeconvolution2D(CNN2DFormat format) {
int nOut = 2; int nOut = 2;
int[] minibatchSizes = new int[] { 1, 3, 3, 1, 3 }; int[] minibatchSizes = new int[] { 1, 3, 3, 1, 3 };
int[] kernelSizes = new int[] { 1, 1, 1, 3, 3 }; int[] kernelSizes = new int[] { 1, 1, 1, 3, 3 };
@ -641,7 +661,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Separable Conv 2 D") @DisplayName("Test Separable Conv 2 D")
void testSeparableConv2D() { @ParameterizedTest
@MethodSource("#params")
void testSeparableConv2D(CNN2DFormat format) {
int nOut = 2; int nOut = 2;
int width = 6; int width = 6;
int height = 6; int height = 6;
@ -686,7 +708,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Dilated") @DisplayName("Test Cnn Dilated")
void testCnnDilated() { @ParameterizedTest
@MethodSource("#params")
void testCnnDilated(CNN2DFormat format) {
int nOut = 2; int nOut = 2;
int minibatchSize = 2; int minibatchSize = 2;
int width = 8; int width = 8;
@ -736,7 +760,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cropping 2 D Layer") @DisplayName("Test Cropping 2 D Layer")
void testCropping2DLayer() { @ParameterizedTest
@MethodSource("#params")
void testCropping2DLayer(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 2; int nOut = 2;
int width = 12; int width = 12;
@ -780,7 +806,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Depthwise Conv 2 D") @DisplayName("Test Depthwise Conv 2 D")
void testDepthwiseConv2D() { @ParameterizedTest
@MethodSource("#params")
void testDepthwiseConv2D(CNN2DFormat format) {
int nIn = 3; int nIn = 3;
int depthMultiplier = 2; int depthMultiplier = 2;
int nOut = nIn * depthMultiplier; int nOut = nIn * depthMultiplier;

View File

@ -39,8 +39,10 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -55,26 +57,22 @@ import java.io.File;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.io.InputStream; import java.io.InputStream;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class YoloGradientCheckTests extends BaseDL4JTest { public class YoloGradientCheckTests extends BaseDL4JTest {
static { static {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
private CNN2DFormat format;
public YoloGradientCheckTests(CNN2DFormat format){
this.format = format;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return CNN2DFormat.values();
}
public static Stream<Arguments> params() {
return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of);
}
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
@ -82,7 +80,9 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
} }
@Test @Test
public void testYoloOutputLayer() { @ParameterizedTest
@MethodSource("#params")
public void testYoloOutputLayer(CNN2DFormat format) {
int depthIn = 2; int depthIn = 2;
int c = 3; int c = 3;
int b = 3; int b = 3;
@ -159,13 +159,13 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
} }
} }
private static INDArray yoloLabels(int mb, int c, int h, int w){ private static INDArray yoloLabels(int mb, int c, int h, int w) {
int labelDepth = 4 + c; int labelDepth = 4 + c;
INDArray labels = Nd4j.zeros(mb, labelDepth, h, w); INDArray labels = Nd4j.zeros(mb, labelDepth, h, w);
//put 1 object per minibatch, at positions (0,0), (1,1) etc. //put 1 object per minibatch, at positions (0,0), (1,1) etc.
//Positions for label boxes: (1,1) to (2,2), (2,2) to (4,4) etc //Positions for label boxes: (1,1) to (2,2), (2,2) to (4,4) etc
for( int i=0; i<mb; i++ ){ for( int i = 0; i < mb; i++) {
//Class labels //Class labels
labels.putScalar(i, 4 + i%c, i%h, i%w, 1); labels.putScalar(i, 4 + i%c, i%h, i%w, 1);
@ -181,7 +181,9 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
@Test @Test
public void yoloGradientCheckRealData(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("#params")
public void yoloGradientCheckRealData(@TempDir Path testDir,CNN2DFormat format) throws Exception {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
InputStream is1 = new ClassPathResource("yolo/VOC_TwoImage/JPEGImages/2007_009346.jpg").getInputStream(); InputStream is1 = new ClassPathResource("yolo/VOC_TwoImage/JPEGImages/2007_009346.jpg").getInputStream();
InputStream is2 = new ClassPathResource("yolo/VOC_TwoImage/Annotations/2007_009346.xml").getInputStream(); InputStream is2 = new ClassPathResource("yolo/VOC_TwoImage/Annotations/2007_009346.xml").getInputStream();

View File

@ -39,8 +39,10 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ConvolutionUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -48,24 +50,19 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class ConvDataFormatTests extends BaseDL4JTest { public class ConvDataFormatTests extends BaseDL4JTest {
private final DataType dataType;
public ConvDataFormatTests(DataType dataType){ public static Stream<Arguments> params(){
this.dataType = dataType; return Arrays.asList(new DataType[]{DataType.FLOAT, DataType.DOUBLE}).stream().map(Arguments::of);
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return new DataType[]{DataType.FLOAT, DataType.DOUBLE};
} }
@Override @Override
@ -74,7 +71,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testConv2d() { @MethodSource("#params")
@ParameterizedTest
public void testConv2d(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -83,15 +82,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getConv2dNet(CNN2DFormat.NCHW, true, cm)) .net1(getConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getConv2dNet(CNN2DFormat.NCHW, false, cm)) .net2(getConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getConv2dNet(CNN2DFormat.NHWC, true, cm)) .net3(getConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getConv2dNet(CNN2DFormat.NHWC, false, cm)) .net4(getConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -107,7 +106,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testSubsampling2d() { @MethodSource("#params")
@ParameterizedTest
public void testSubsampling2d(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -116,15 +117,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getSubsampling2dNet(CNN2DFormat.NCHW, true, cm)) .net1(getSubsampling2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getSubsampling2dNet(CNN2DFormat.NCHW, false, cm)) .net2(getSubsampling2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getSubsampling2dNet(CNN2DFormat.NHWC, true, cm)) .net3(getSubsampling2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getSubsampling2dNet(CNN2DFormat.NHWC, false, cm)) .net4(getSubsampling2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -140,7 +141,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testDepthwiseConv2d() { @MethodSource("#params")
@ParameterizedTest
public void testDepthwiseConv2d(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -149,15 +152,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getDepthwiseConv2dNet(CNN2DFormat.NCHW, true, cm)) .net1(getDepthwiseConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getDepthwiseConv2dNet(CNN2DFormat.NCHW, false, cm)) .net2(getDepthwiseConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getDepthwiseConv2dNet(CNN2DFormat.NHWC, true, cm)) .net3(getDepthwiseConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getDepthwiseConv2dNet(CNN2DFormat.NHWC, false, cm)) .net4(getDepthwiseConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -173,7 +176,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testSeparableConv2d() { @MethodSource("#params")
@ParameterizedTest
public void testSeparableConv2d(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -182,15 +187,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getSeparableConv2dNet(CNN2DFormat.NCHW, true, cm)) .net1(getSeparableConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getSeparableConv2dNet(CNN2DFormat.NCHW, false, cm)) .net2(getSeparableConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getSeparableConv2dNet(CNN2DFormat.NHWC, true, cm)) .net3(getSeparableConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getSeparableConv2dNet(CNN2DFormat.NHWC, false, cm)) .net4(getSeparableConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -206,7 +211,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testDeconv2d() { @MethodSource("#params")
@ParameterizedTest
public void testDeconv2d(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -215,15 +222,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getDeconv2DNet2dNet(CNN2DFormat.NCHW, true, cm)) .net1(getDeconv2DNet2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getDeconv2DNet2dNet(CNN2DFormat.NCHW, false, cm)) .net2(getDeconv2DNet2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getDeconv2DNet2dNet(CNN2DFormat.NHWC, true, cm)) .net3(getDeconv2DNet2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getDeconv2DNet2dNet(CNN2DFormat.NHWC, false, cm)) .net4(getDeconv2DNet2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -239,7 +246,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testLRN() { @MethodSource("#params")
@ParameterizedTest
public void testLRN(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -248,15 +257,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getLrnLayer(CNN2DFormat.NCHW, true, cm)) .net1(getLrnLayer(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getLrnLayer(CNN2DFormat.NCHW, false, cm)) .net2(getLrnLayer(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getLrnLayer(CNN2DFormat.NHWC, true, cm)) .net3(getLrnLayer(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getLrnLayer(CNN2DFormat.NHWC, false, cm)) .net4(getLrnLayer(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -272,7 +281,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testZeroPaddingLayer(){ @MethodSource("#params")
@ParameterizedTest
public void testZeroPaddingLayer(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -280,15 +291,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers"; String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getZeroPaddingNet(CNN2DFormat.NCHW, true)) .net1(getZeroPaddingNet(dataType,CNN2DFormat.NCHW, true))
.net2(getZeroPaddingNet(CNN2DFormat.NCHW, false)) .net2(getZeroPaddingNet(dataType,CNN2DFormat.NCHW, false))
.net3(getZeroPaddingNet(CNN2DFormat.NHWC, true)) .net3(getZeroPaddingNet(dataType,CNN2DFormat.NHWC, true))
.net4(getZeroPaddingNet(CNN2DFormat.NHWC, false)) .net4(getZeroPaddingNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -303,7 +314,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testCropping2DLayer(){ @MethodSource("#params")
@ParameterizedTest
public void testCropping2DLayer(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -311,15 +324,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers"; String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getCropping2dNet(CNN2DFormat.NCHW, true)) .net1(getCropping2dNet(dataType,CNN2DFormat.NCHW, true))
.net2(getCropping2dNet(CNN2DFormat.NCHW, false)) .net2(getCropping2dNet(dataType,CNN2DFormat.NCHW, false))
.net3(getCropping2dNet(CNN2DFormat.NHWC, true)) .net3(getCropping2dNet(dataType,CNN2DFormat.NHWC, true))
.net4(getCropping2dNet(CNN2DFormat.NHWC, false)) .net4(getCropping2dNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -334,7 +347,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testUpsampling2d(){ @MethodSource("#params")
@ParameterizedTest
public void testUpsampling2d(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -342,15 +357,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers"; String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getUpsamplingNet(CNN2DFormat.NCHW, true)) .net1(getUpsamplingNet(dataType,CNN2DFormat.NCHW, true))
.net2(getUpsamplingNet(CNN2DFormat.NCHW, false)) .net2(getUpsamplingNet(dataType,CNN2DFormat.NCHW, false))
.net3(getUpsamplingNet(CNN2DFormat.NHWC, true)) .net3(getUpsamplingNet(dataType,CNN2DFormat.NHWC, true))
.net4(getUpsamplingNet(CNN2DFormat.NHWC, false)) .net4(getUpsamplingNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -365,7 +380,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testBatchNormNet(){ @MethodSource("#params")
@ParameterizedTest
public void testBatchNormNet(DataType dataType) {
try { try {
for(boolean useLogStd : new boolean[]{true, false}) { for(boolean useLogStd : new boolean[]{true, false}) {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
@ -374,15 +391,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std"); String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std");
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, true)) .net1(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NCHW, true))
.net2(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, false)) .net2(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NCHW, false))
.net3(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, true)) .net3(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NHWC, true))
.net4(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, false)) .net4(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NHWC, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -398,7 +415,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testCnnLossLayer() { @MethodSource("#params")
@ParameterizedTest
public void testCnnLossLayer(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -406,8 +425,8 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers"; String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labelsNHWC = TestUtils.randomOneHot(this.dataType,2*6*6, 3); INDArray labelsNHWC = TestUtils.randomOneHot(dataType,2*6*6, 3);
labelsNHWC = labelsNHWC.reshape(2,6,6,3); labelsNHWC = labelsNHWC.reshape(2,6,6,3);
INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup(); INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup();
@ -434,7 +453,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testSpaceToDepthNet(){ @MethodSource("#params")
@ParameterizedTest
public void testSpaceToDepthNet(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -442,15 +463,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers"; String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getSpaceToDepthNet(CNN2DFormat.NCHW, true)) .net1(getSpaceToDepthNet(dataType,CNN2DFormat.NCHW, true))
.net2(getSpaceToDepthNet(CNN2DFormat.NCHW, false)) .net2(getSpaceToDepthNet(dataType,CNN2DFormat.NCHW, false))
.net3(getSpaceToDepthNet(CNN2DFormat.NHWC, true)) .net3(getSpaceToDepthNet(dataType,CNN2DFormat.NHWC, true))
.net4(getSpaceToDepthNet(CNN2DFormat.NHWC, false)) .net4(getSpaceToDepthNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -465,7 +486,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testSpaceToBatchNet(){ @MethodSource("#params")
@ParameterizedTest
public void testSpaceToBatchNet(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -473,15 +496,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers"; String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 16, 16); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 16, 16);
INDArray labels = TestUtils.randomOneHot(8, 10); INDArray labels = TestUtils.randomOneHot(8, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getSpaceToBatchNet(CNN2DFormat.NCHW, true)) .net1(getSpaceToBatchNet(dataType,CNN2DFormat.NCHW, true))
.net2(getSpaceToBatchNet(CNN2DFormat.NCHW, false)) .net2(getSpaceToBatchNet(dataType,CNN2DFormat.NCHW, false))
.net3(getSpaceToBatchNet(CNN2DFormat.NHWC, true)) .net3(getSpaceToBatchNet(dataType,CNN2DFormat.NHWC, true))
.net4(getSpaceToBatchNet(CNN2DFormat.NHWC, false)) .net4(getSpaceToBatchNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -496,7 +519,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testLocallyConnected() { @MethodSource("#params")
@ParameterizedTest
public void testLocallyConnected(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -505,15 +530,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getLocallyConnectedNet(CNN2DFormat.NCHW, true, cm)) .net1(getLocallyConnectedNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getLocallyConnectedNet(CNN2DFormat.NCHW, false, cm)) .net2(getLocallyConnectedNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getLocallyConnectedNet(CNN2DFormat.NHWC, true, cm)) .net3(getLocallyConnectedNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getLocallyConnectedNet(CNN2DFormat.NHWC, false, cm)) .net4(getLocallyConnectedNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -530,7 +555,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
@Test @Test
public void testGlobalPooling() { @MethodSource("#params")
@ParameterizedTest
public void testGlobalPooling(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (PoolingType pt : PoolingType.values()) { for (PoolingType pt : PoolingType.values()) {
@ -539,15 +566,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")"; String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true)) .net1(getGlobalPoolingNet(dataType,CNN2DFormat.NCHW, pt, true))
.net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false)) .net2(getGlobalPoolingNet(dataType,CNN2DFormat.NCHW, pt, false))
.net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true)) .net3(getGlobalPoolingNet(dataType,CNN2DFormat.NHWC, pt, true))
.net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false)) .net4(getGlobalPoolingNet(dataType,CNN2DFormat.NHWC, pt, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -562,9 +589,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new ConvolutionLayer.Builder() return getNetWithLayer(dataType,new ConvolutionLayer.Builder()
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
@ -573,7 +600,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.helperAllowFallback(false) .helperAllowFallback(false)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new ConvolutionLayer.Builder() return getNetWithLayer(dataType,new ConvolutionLayer.Builder()
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
@ -583,16 +610,16 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getSubsampling2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getSubsampling2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new SubsamplingLayer.Builder() return getNetWithLayer(dataType,new SubsamplingLayer.Builder()
.kernelSize(2, 2) .kernelSize(2, 2)
.stride(1, 1) .stride(1, 1)
.dataFormat(format) .dataFormat(format)
.helperAllowFallback(false) .helperAllowFallback(false)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new SubsamplingLayer.Builder() return getNetWithLayer(dataType,new SubsamplingLayer.Builder()
.kernelSize(2, 2) .kernelSize(2, 2)
.stride(1, 1) .stride(1, 1)
.helperAllowFallback(false) .helperAllowFallback(false)
@ -600,9 +627,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getSeparableConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getSeparableConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new SeparableConvolution2D.Builder() return getNetWithLayer(dataType,new SeparableConvolution2D.Builder()
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
@ -611,7 +638,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.helperAllowFallback(false) .helperAllowFallback(false)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new SeparableConvolution2D.Builder() return getNetWithLayer(dataType,new SeparableConvolution2D.Builder()
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
@ -621,9 +648,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getDepthwiseConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getDepthwiseConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new DepthwiseConvolution2D.Builder() return getNetWithLayer(dataType,new DepthwiseConvolution2D.Builder()
.depthMultiplier(2) .depthMultiplier(2)
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
@ -633,7 +660,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.helperAllowFallback(false) .helperAllowFallback(false)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new DepthwiseConvolution2D.Builder() return getNetWithLayer(dataType,new DepthwiseConvolution2D.Builder()
.depthMultiplier(2) .depthMultiplier(2)
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
@ -644,59 +671,59 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getLrnLayer(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getLrnLayer(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new LocalResponseNormalization.Builder() return getNetWithLayer(dataType,new LocalResponseNormalization.Builder()
.dataFormat(format) .dataFormat(format)
.helperAllowFallback(false) .helperAllowFallback(false)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new LocalResponseNormalization.Builder() return getNetWithLayer(dataType,new LocalResponseNormalization.Builder()
.helperAllowFallback(false) .helperAllowFallback(false)
.build(), format, cm, null); .build(), format, cm, null);
} }
} }
private MultiLayerNetwork getZeroPaddingNet(CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getZeroPaddingNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2) return getNetWithLayer(dataType,new ZeroPaddingLayer.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null); .dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else { } else {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2).build(), return getNetWithLayer(dataType,new ZeroPaddingLayer.Builder(2,2).build(),
format, ConvolutionMode.Same, null); format, ConvolutionMode.Same, null);
} }
} }
private MultiLayerNetwork getCropping2dNet(CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getCropping2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new Cropping2D.Builder(2,2) return getNetWithLayer(dataType,new Cropping2D.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null); .dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else { } else {
return getNetWithLayer(new Cropping2D.Builder(2,2) return getNetWithLayer(dataType,new Cropping2D.Builder(2,2)
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} }
} }
private MultiLayerNetwork getUpsamplingNet(CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getUpsamplingNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new Upsampling2D.Builder(2) return getNetWithLayer(dataType,new Upsampling2D.Builder(2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null); .dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else { } else {
return getNetWithLayer(new Upsampling2D.Builder(2) return getNetWithLayer(dataType,new Upsampling2D.Builder(2)
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} }
} }
private MultiLayerNetwork getDeconv2DNet2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getDeconv2DNet2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2) return getNetWithLayer(dataType,new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH) .activation(Activation.TANH)
.kernelSize(2,2) .kernelSize(2,2)
.dataFormat(format) .dataFormat(format)
.stride(2,2) .stride(2,2)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2) return getNetWithLayer(dataType,new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH) .activation(Activation.TANH)
.kernelSize(2,2) .kernelSize(2,2)
.dataFormat(format) .dataFormat(format)
@ -705,50 +732,50 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getBatchNormNet(boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getBatchNormNet(DataType dataType,boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new BatchNormalization.Builder() return getNetWithLayer(dataType,new BatchNormalization.Builder()
.useLogStd(logStdev) .useLogStd(logStdev)
.dataFormat(format) .dataFormat(format)
.helperAllowFallback(false) .helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null); .nOut(3).build(), format, ConvolutionMode.Same, null);
} else { } else {
return getNetWithLayer(new BatchNormalization.Builder() return getNetWithLayer(dataType,new BatchNormalization.Builder()
.useLogStd(logStdev) .useLogStd(logStdev)
.helperAllowFallback(false) .helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null); .nOut(3).build(), format, ConvolutionMode.Same, null);
} }
} }
private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getSpaceToDepthNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToDepthLayer.Builder() return getNetWithLayer(dataType,new SpaceToDepthLayer.Builder()
.blocks(2) .blocks(2)
.dataFormat(format) .dataFormat(format)
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} else { } else {
return getNetWithLayer(new SpaceToDepthLayer.Builder() return getNetWithLayer(dataType,new SpaceToDepthLayer.Builder()
.blocks(2) .blocks(2)
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} }
} }
private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getSpaceToBatchNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToBatchLayer.Builder() return getNetWithLayer(dataType,new SpaceToBatchLayer.Builder()
.blocks(2, 2) .blocks(2, 2)
.dataFormat(format) .dataFormat(format)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
} else { } else {
return getNetWithLayer(new SpaceToBatchLayer.Builder() return getNetWithLayer(dataType,new SpaceToBatchLayer.Builder()
.blocks(2, 2) .blocks(2, 2)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
} }
} }
private MultiLayerNetwork getLocallyConnectedNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getLocallyConnectedNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new LocallyConnected2D.Builder() return getNetWithLayer(dataType,new LocallyConnected2D.Builder()
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
@ -756,7 +783,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.nOut(3) .nOut(3)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new LocallyConnected2D.Builder() return getNetWithLayer(dataType,new LocallyConnected2D.Builder()
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
@ -765,9 +792,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) { private MultiLayerNetwork getNetWithLayer(DataType dataType,Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) {
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.dataType(this.dataType) .dataType(dataType)
.seed(12345) .seed(12345)
.convolutionMode(cm) .convolutionMode(cm)
.list() .list()
@ -794,13 +821,13 @@ public class ConvDataFormatTests extends BaseDL4JTest {
return net; return net;
} }
private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) { private MultiLayerNetwork getGlobalPoolingNet(DataType dataType,CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt) return getNetWithLayer(dataType,new GlobalPoolingLayer.Builder(pt)
.poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2}) .poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2})
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} else { } else {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt) return getNetWithLayer(dataType,new GlobalPoolingLayer.Builder(pt)
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} }
} }

View File

@ -45,8 +45,11 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.TimeSeriesUtils; import org.deeplearning4j.util.TimeSeriesUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -61,30 +64,29 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.deeplearning4j.nn.conf.RNNFormat.NCW; import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
@DisplayName("Bidirectional Test") @DisplayName("Bidirectional Test")
class BidirectionalTest extends BaseDL4JTest { class BidirectionalTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public BidirectionalTest(RNNFormat rnnDataFormat) {
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters public static Stream<Arguments> params() {
public static Object[] params() { return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
return RNNFormat.values();
} }
@Test @Test
@DisplayName("Compare Implementations") @DisplayName("Compare Implementations")
void compareImplementations() { @ParameterizedTest
@MethodSource("#params")
void compareImplementations(RNNFormat rnnDataFormat) {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
// Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params // Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params
@ -147,9 +149,11 @@ class BidirectionalTest extends BaseDL4JTest {
} }
} }
@Test
@DisplayName("Compare Implementations Comp Graph") @DisplayName("Compare Implementations Comp Graph")
void compareImplementationsCompGraph() { @Test
@ParameterizedTest
@MethodSource("#params")
void compareImplementationsCompGraph(RNNFormat rnnFormat) {
// for(WorkspaceMode wsm : WorkspaceMode.values()) { // for(WorkspaceMode wsm : WorkspaceMode.values()) {
for (WorkspaceMode wsm : new WorkspaceMode[] { WorkspaceMode.NONE, WorkspaceMode.ENABLED }) { for (WorkspaceMode wsm : new WorkspaceMode[] { WorkspaceMode.NONE, WorkspaceMode.ENABLED }) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
@ -187,8 +191,8 @@ class BidirectionalTest extends BaseDL4JTest {
Gradient g2 = net2.gradient(); Gradient g2 = net2.gradient();
assertEquals(g1.gradient(), g2.gradient()); assertEquals(g1.gradient(), g2.gradient());
// Ensure updates are equal: // Ensure updates are equal:
ComputationGraphUpdater u1 = (ComputationGraphUpdater) net1.getUpdater(); ComputationGraphUpdater u1 = net1.getUpdater();
ComputationGraphUpdater u2 = (ComputationGraphUpdater) net2.getUpdater(); ComputationGraphUpdater u2 = net2.getUpdater();
assertEquals(u1.getUpdaterStateViewArray(), u2.getUpdaterStateViewArray()); assertEquals(u1.getUpdaterStateViewArray(), u2.getUpdaterStateViewArray());
u1.update(g1, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); u1.update(g1, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces());
u2.update(g2, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); u2.update(g2, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces());
@ -205,7 +209,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Serialization") @DisplayName("Test Serialization")
void testSerialization() throws Exception { @ParameterizedTest
@MethodSource("#params")
void testSerialization(RNNFormat rnnDataFormat) throws Exception {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -242,7 +248,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Serialization Comp Graph") @DisplayName("Test Serialization Comp Graph")
void testSerializationCompGraph() throws Exception { @ParameterizedTest
@MethodSource("#params")
void testSerializationCompGraph(RNNFormat rnnDataFormat) throws Exception {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -277,7 +285,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Simple Bidirectional") @DisplayName("Test Simple Bidirectional")
void testSimpleBidirectional() { @ParameterizedTest
@MethodSource("#params")
public void testSimpleBidirectional(RNNFormat rnnDataFormat) {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -362,7 +372,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Simple Bidirectional Comp Graph") @DisplayName("Test Simple Bidirectional Comp Graph")
void testSimpleBidirectionalCompGraph() { @ParameterizedTest
@MethodSource("#params")
void testSimpleBidirectionalCompGraph(RNNFormat rnnDataFormat) {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -19,7 +19,6 @@
*/ */
package org.deeplearning4j.nn.layers.recurrent; package org.deeplearning4j.nn.layers.recurrent;
import junit.framework.TestCase;
import lombok.val; import lombok.val;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@ -34,9 +33,12 @@ import org.deeplearning4j.nn.params.GravesBidirectionalLSTMParamInitializer;
import org.deeplearning4j.nn.params.GravesLSTMParamInitializer; import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -44,31 +46,29 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.common.primitives.Pair;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(Parameterized.class) import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Graves Bidirectional LSTM Test") @DisplayName("Graves Bidirectional LSTM Test")
class GravesBidirectionalLSTMTest extends BaseDL4JTest { class GravesBidirectionalLSTMTest extends BaseDL4JTest {
private double score = 0.0; private double score = 0.0;
private RNNFormat rnnDataFormat;
public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat) {
this.rnnDataFormat = rnnDataFormat; public static Stream<Arguments> params(){
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
} }
@Parameterized.Parameters
public static Object[] params() {
return RNNFormat.values();
}
@Test @Test
@DisplayName("Test Bidirectional LSTM Graves Forward Basic") @DisplayName("Test Bidirectional LSTM Graves Forward Basic")
void testBidirectionalLSTMGravesForwardBasic() { @MethodSource("#params")
@ParameterizedTest
void testBidirectionalLSTMGravesForwardBasic(RNNFormat rnnDataFormat) {
// Very basic test of forward prop. of LSTM layer with a time series. // Very basic test of forward prop. of LSTM layer with a time series.
// Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. // Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape.
int nIn = 13; int nIn = 13;
@ -110,19 +110,21 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Bidirectional LSTM Graves Backward Basic") @DisplayName("Test Bidirectional LSTM Graves Backward Basic")
void testBidirectionalLSTMGravesBackwardBasic() { @MethodSource("#params")
@ParameterizedTest
void testBidirectionalLSTMGravesBackwardBasic(RNNFormat rnnDataFormat) {
// Very basic test of backprop for mini-batch + time series // Very basic test of backprop for mini-batch + time series
// Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. // Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape.
testGravesBackwardBasicHelper(13, 3, 17, 10, 7); testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 7);
// Edge case: miniBatchSize = 1 // Edge case: miniBatchSize = 1
testGravesBackwardBasicHelper(13, 3, 17, 1, 7); testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 1, 7);
// Edge case: timeSeriesLength = 1 // Edge case: timeSeriesLength = 1
testGravesBackwardBasicHelper(13, 3, 17, 10, 1); testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 1);
// Edge case: both miniBatchSize = 1 and timeSeriesLength = 1 // Edge case: both miniBatchSize = 1 and timeSeriesLength = 1
testGravesBackwardBasicHelper(13, 3, 17, 1, 1); testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 1, 1);
} }
private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) { private void testGravesBackwardBasicHelper(RNNFormat rnnDataFormat,int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) {
INDArray inputData = (rnnDataFormat == RNNFormat.NCW) ? Nd4j.ones(miniBatchSize, nIn, timeSeriesLength) : Nd4j.ones(miniBatchSize, timeSeriesLength, nIn); INDArray inputData = (rnnDataFormat == RNNFormat.NCW) ? Nd4j.ones(miniBatchSize, nIn, timeSeriesLength) : Nd4j.ones(miniBatchSize, timeSeriesLength, nIn);
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat).dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()).build(); NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat).dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()).build();
long numParams = conf.getLayer().initializer().numParams(conf); long numParams = conf.getLayer().initializer().numParams(conf);
@ -204,7 +206,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Get Set Parmas") @DisplayName("Test Get Set Parmas")
void testGetSetParmas() { @MethodSource("#params")
@ParameterizedTest
void testGetSetParmas(RNNFormat rnnDataFormat) {
final int nIn = 2; final int nIn = 2;
final int layerSize = 3; final int layerSize = 3;
final int miniBatchSize = 2; final int miniBatchSize = 2;
@ -224,7 +228,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Simple Forwards And Backwards Activation") @DisplayName("Test Simple Forwards And Backwards Activation")
void testSimpleForwardsAndBackwardsActivation() { @MethodSource("#params")
@ParameterizedTest
void testSimpleForwardsAndBackwardsActivation(RNNFormat rnnDataFormat) {
final int nIn = 2; final int nIn = 2;
final int layerSize = 3; final int layerSize = 3;
final int miniBatchSize = 1; final int miniBatchSize = 1;
@ -342,7 +348,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Gate Activation Fns Sanity Check") @DisplayName("Test Gate Activation Fns Sanity Check")
void testGateActivationFnsSanityCheck() { @MethodSource("#params")
@ParameterizedTest
void testGateActivationFnsSanityCheck(RNNFormat rnnDataFormat) {
for (String gateAfn : new String[] { "sigmoid", "hardsigmoid" }) { for (String gateAfn : new String[] { "sigmoid", "hardsigmoid" }) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat).activation(Activation.TANH).build()).build(); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat).activation(Activation.TANH).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);

View File

@ -30,36 +30,35 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(Parameterized.class)
@DisplayName("Mask Zero Layer Test") @DisplayName("Mask Zero Layer Test")
class MaskZeroLayerTest extends BaseDL4JTest { class MaskZeroLayerTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public MaskZeroLayerTest(RNNFormat rnnDataFormat) { public static Stream<Arguments> params() {
this.rnnDataFormat = rnnDataFormat; return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
} }
@Parameterized.Parameters
public static Object[] params() {
return RNNFormat.values();
}
@Test
@DisplayName("Activate") @DisplayName("Activate")
void activate() { @Test
@ParameterizedTest
@MethodSource("#params")
void activate(RNNFormat rnnDataFormat) {
// GIVEN two examples where some of the timesteps are zero. // GIVEN two examples where some of the timesteps are zero.
INDArray ex1 = Nd4j.create(new double[][] { new double[] { 0, 3, 5 }, new double[] { 0, 0, 2 } }); INDArray ex1 = Nd4j.create(new double[][] { new double[] { 0, 3, 5 }, new double[] { 0, 0, 2 } });
INDArray ex2 = Nd4j.create(new double[][] { new double[] { 0, 0, 2 }, new double[] { 0, 0, 2 } }); INDArray ex2 = Nd4j.create(new double[][] { new double[] { 0, 0, 2 }, new double[] { 0, 0, 2 } });
@ -95,9 +94,12 @@ class MaskZeroLayerTest extends BaseDL4JTest {
assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6); assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6);
} }
@Test
@DisplayName("Test Serialization") @DisplayName("Test Serialization")
void testSerialization() { @Test
@ParameterizedTest
@MethodSource("#params")
void testSerialization(RNNFormat rnnDataFormat) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder().setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()).build(); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder().setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();

View File

@ -40,8 +40,10 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -51,30 +53,31 @@ import org.nd4j.common.primitives.Pair;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
@AllArgsConstructor @AllArgsConstructor
public class RnnDataFormatTests extends BaseDL4JTest { public class RnnDataFormatTests extends BaseDL4JTest {
private boolean helpers;
private boolean lastTimeStep;
private boolean maskZeros;
@Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}") public static Stream<Arguments> params() {
public static List params(){
List<Object[]> ret = new ArrayList<>(); List<Object[]> ret = new ArrayList<>();
for (boolean helpers: new boolean[]{true, false}) for (boolean helpers: new boolean[]{true, false})
for (boolean lastTimeStep: new boolean[]{true, false}) for (boolean lastTimeStep: new boolean[]{true, false})
for (boolean maskZero: new boolean[]{true, false}) for (boolean maskZero: new boolean[]{true, false})
ret.add(new Object[]{helpers, lastTimeStep, maskZero}); ret.add(new Object[]{helpers, lastTimeStep, maskZero});
return ret; return ret.stream().map(Arguments::of);
} }
@Test @Test
public void testSimpleRnn() { @MethodSource("#params")
@ParameterizedTest
public void testSimpleRnn(boolean helpers,
boolean lastTimeStep,
boolean maskZeros
) {
try { try {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -107,7 +110,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testLSTM() { @ParameterizedTest
@MethodSource("#params")
public void testLSTM(boolean helpers,
boolean lastTimeStep,
boolean maskZeros) {
try { try {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -141,7 +148,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
@Test @Test
public void testGraveLSTM() { @MethodSource("#params")
@ParameterizedTest
public void testGraveLSTM(boolean helpers,
boolean lastTimeStep,
boolean maskZeros) {
try { try {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -175,7 +186,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
@Test @Test
public void testGraveBiLSTM() { @MethodSource("#params")
@ParameterizedTest
public void testGraveBiLSTM(boolean helpers,
boolean lastTimeStep,
boolean maskZeros) {
try { try {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -34,14 +34,20 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.AdaGrad;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.deeplearning4j.nn.api.OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT; import static org.deeplearning4j.nn.api.OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
import static org.deeplearning4j.nn.weights.WeightInit.XAVIER_UNIFORM; import static org.deeplearning4j.nn.weights.WeightInit.XAVIER_UNIFORM;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@ -50,20 +56,16 @@ import static org.nd4j.linalg.activations.Activation.TANH;
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE; import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;
@RunWith(Parameterized.class)
public class TestLastTimeStepLayer extends BaseDL4JTest { public class TestLastTimeStepLayer extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestLastTimeStepLayer(RNNFormat rnnDataFormat){ public static Stream<Arguments> params(){
this.rnnDataFormat = rnnDataFormat; return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Parameterized.Parameters(name="{0}")
public static Object[] params(){
return RNNFormat.values();
} }
@Test @Test
public void testLastTimeStepVertex() { @ParameterizedTest
@MethodSource("#params")
public void testLastTimeStepVertex(RNNFormat rnnDataFormat) {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
.addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder() .addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder()
@ -126,7 +128,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
} }
@Test @Test
public void testMaskingAndAllMasked(){ @ParameterizedTest
@MethodSource("#params")
public void testMaskingAndAllMasked(RNNFormat rnnDataFormat) {
ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder() ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
.optimizationAlgo(STOCHASTIC_GRADIENT_DESCENT) .optimizationAlgo(STOCHASTIC_GRADIENT_DESCENT)
.weightInit(XAVIER_UNIFORM) .weightInit(XAVIER_UNIFORM)

View File

@ -36,8 +36,11 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -49,25 +52,23 @@ import org.nd4j.common.primitives.Pair;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class TestRnnLayers extends BaseDL4JTest { public class TestRnnLayers extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestRnnLayers(RNNFormat rnnDataFormat){ public static Stream<Arguments> params(){
this.rnnDataFormat = rnnDataFormat; return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
} }
@Test @Test
public void testTimeStepIs3Dimensional() { @ParameterizedTest
@MethodSource("#params")
public void testTimeStepIs3Dimensional(RNNFormat rnnDataFormat) {
int nIn = 12; int nIn = 12;
int nOut = 3; int nOut = 3;
@ -117,7 +118,9 @@ public class TestRnnLayers extends BaseDL4JTest {
} }
@Test @Test
public void testDropoutRecurrentLayers(){ @ParameterizedTest
@MethodSource("#params")
public void testDropoutRecurrentLayers(RNNFormat rnnDataFormat){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
String[] layerTypes = new String[]{"graves", "lstm", "simple"}; String[] layerTypes = new String[]{"graves", "lstm", "simple"};
@ -215,9 +218,11 @@ public class TestRnnLayers extends BaseDL4JTest {
} }
@Test @Test
public void testMismatchedInputLabelLength(){ @ParameterizedTest
@MethodSource("#params")
public void testMismatchedInputLabelLength(RNNFormat rnnDataFormat){
for( int i=0; i<2; i++ ){ for( int i = 0; i < 2; i++) {
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder() NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()

View File

@ -29,8 +29,10 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -38,25 +40,25 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point; import static org.nd4j.linalg.indexing.NDArrayIndex.point;
@RunWith(Parameterized.class)
public class TestSimpleRnn extends BaseDL4JTest { public class TestSimpleRnn extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestSimpleRnn(RNNFormat rnnDataFormat){ public static Stream<Arguments> params() {
this.rnnDataFormat = rnnDataFormat; return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
} }
@Test @Test
public void testSimpleRnn(){ @ParameterizedTest
@MethodSource("#params")
public void testSimpleRnn(RNNFormat rnnDataFormat) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int m = 3; int m = 3;
@ -125,7 +127,9 @@ public class TestSimpleRnn extends BaseDL4JTest {
} }
@Test @Test
public void testBiasInit(){ @ParameterizedTest
@MethodSource("#params")
public void testBiasInit(RNNFormat rnnDataFormat) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 5; int nIn = 5;
int layerSize = 6; int layerSize = 6;

View File

@ -37,8 +37,10 @@ import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -47,22 +49,22 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class TestTimeDistributed extends BaseDL4JTest { public class TestTimeDistributed extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestTimeDistributed(RNNFormat rnnDataFormat){ public static Stream<Arguments> params(){
this.rnnDataFormat = rnnDataFormat; return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
} }
@Test @Test
public void testTimeDistributed(){ @ParameterizedTest
@MethodSource("#params")
public void testTimeDistributed(RNNFormat rnnDataFormat){
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) { for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
@ -133,10 +135,12 @@ public class TestTimeDistributed extends BaseDL4JTest {
@Test @Test
public void testTimeDistributedDense(){ @MethodSource("#params")
@ParameterizedTest
public void testTimeDistributedDense(RNNFormat rnnDataFormat){
for( int rnnType=0; rnnType<3; rnnType++ ) { for( int rnnType = 0; rnnType < 3; rnnType++ ) {
for( int ffType=0; ffType<3; ffType++ ) { for( int ffType = 0; ffType < 3; ffType++ ) {
Layer l0, l2; Layer l0, l2;
switch (rnnType) { switch (rnnType) {

View File

@ -39,8 +39,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -145,92 +145,6 @@
</execution> </execution>
</executions> </executions>
</plugin> </plugin>
<plugin>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-plugin</artifactId>
<version>1.4.30-M1</version>
<configuration>
<args>
<arg>-Xjsr305=strict</arg>
</args>
<compilerPlugins>
<plugin>spring</plugin>
<plugin>jpa</plugin>
</compilerPlugins>
</configuration>
<dependencies>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-allopen</artifactId>
<version>${kotlin.version}</version>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-noarg</artifactId>
<version>${kotlin.version}</version>
</dependency>
</dependencies>
<executions>
<execution>
<id>compile</id>
<goals> <goal>compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/main/stubs</sourceDir>
<sourceDir>${project.basedir}/src/main/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/main/java</sourceDir>
<sourceDir>${project.basedir}/src/main/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
<execution>
<id>test-compile</id>
<goals> <goal>test-compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/test/stubs</sourceDir>
<sourceDir>${project.basedir}/src/test/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/test/java</sourceDir>
<sourceDir>${project.basedir}/src/test/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
</executions>
</plugin>
<!-- https://kotlinlang.org/docs/reference/using-maven.html -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.5.1</version>
<executions>
<!-- Replacing default-compile as it is treated specially by maven -->
<execution>
<id>default-compile</id>
<phase>none</phase>
</execution>
<!-- Replacing default-testCompile as it is treated specially by maven -->
<execution>
<id>default-testCompile</id>
<phase>none</phase>
</execution>
<execution>
<id>java-compile</id>
<phase>compile</phase>
<goals> <goal>compile</goal> </goals>
</execution>
<execution>
<id>java-test-compile</id>
<phase>test-compile</phase>
<goals> <goal>testCompile</goal> </goals>
</execution>
</executions>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
</plugins> </plugins>
</build> </build>
@ -244,7 +158,10 @@
<groupId>org.junit.jupiter</groupId> <groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId> <artifactId>junit-jupiter-engine</artifactId>
</dependency> </dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
</dependency>
<dependency> <dependency>
<groupId>org.jetbrains.kotlin</groupId> <groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-stdlib-jdk8</artifactId> <artifactId>kotlin-stdlib-jdk8</artifactId>
@ -261,11 +178,14 @@
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>samediff-import-tensorflow</artifactId> <artifactId>samediff-import-tensorflow</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>compile</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>samediff-import-onnx</artifactId> <artifactId>samediff-import-onnx</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>compile</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>

View File

@ -22,11 +22,6 @@ package org.nd4j;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass; import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.imports.tfgraphs.TFGraphTestAllLibnd4j;
import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
import org.nd4j.imports.tfgraphs.TFGraphTestList;
import org.nd4j.imports.tfgraphs.TFGraphTestZooModels;
import org.nd4j.imports.listeners.ImportModelDebugger;
import java.util.*; import java.util.*;
@Slf4j @Slf4j
@ -36,11 +31,6 @@ public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
protected Set<Class<?>> getExclusions() { protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>(Arrays.asList( return new HashSet<>(Arrays.asList(
TFGraphTestAllSameDiff.class,
TFGraphTestAllLibnd4j.class,
TFGraphTestList.class,
TFGraphTestZooModels.class,
ImportModelDebugger.class //Run manually only, otherwise ignored
)); ));
} }

View File

@ -20,19 +20,16 @@
package org.nd4j; package org.nd4j;
import org.bytedeco.javacpp.Loader;
import org.junit.AfterClass; import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.Suite; import org.junit.runners.Suite;
import org.nd4j.autodiff.opvalidation.*; import org.nd4j.autodiff.opvalidation.*;
import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff; //import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.function.Function;
import static org.junit.Assume.assumeFalse; import static org.junit.Assume.assumeFalse;
@ -49,7 +46,7 @@ import static org.junit.Assume.assumeFalse;
TransformOpValidation.class, TransformOpValidation.class,
//TF import tests //TF import tests
TFGraphTestAllSameDiff.class //TFGraphTestAllSameDiff.class
//TFGraphTestAllLibnd4j.class //TFGraphTestAllLibnd4j.class
}) })
//IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test" //IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test"

View File

@ -27,10 +27,12 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.ImportClassMapping; import org.nd4j.imports.converters.ImportClassMapping;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.NoOp; import org.nd4j.linalg.api.ops.NoOp;
import org.nd4j.linalg.api.ops.compat.CompatSparseToDense; import org.nd4j.linalg.api.ops.compat.CompatSparseToDense;
@ -122,13 +124,11 @@ import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
@Disabled("No longer relevant after model import rewrite.") @Disabled("No longer relevant after model import rewrite.")
public class TestOpMapping extends BaseNd4jTest { public class TestOpMapping extends BaseNd4jTestWithBackends {
Set<Class<? extends DifferentialFunction>> subTypes; Set<Class<? extends DifferentialFunction>> subTypes;
public TestOpMapping(Nd4jBackend b){ public TestOpMapping() {
super(b);
Reflections reflections = new Reflections("org.nd4j"); Reflections reflections = new Reflections("org.nd4j");
subTypes = reflections.getSubTypesOf(DifferentialFunction.class); subTypes = reflections.getSubTypesOf(DifferentialFunction.class);
} }
@ -146,6 +146,8 @@ public class TestOpMapping extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOpMappingCoverage() throws Exception { public void testOpMappingCoverage() throws Exception {
Map<String, DifferentialFunction> opNameMapping = ImportClassMapping.getOpNameMapping(); Map<String, DifferentialFunction> opNameMapping = ImportClassMapping.getOpNameMapping();
Map<String, DifferentialFunction> tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions(); Map<String, DifferentialFunction> tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions();
@ -196,7 +198,9 @@ public class TestOpMapping extends BaseNd4jTest {
} }
@Test @Test
public void testOpsInNamespace() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOpsInNamespace(Nd4jBackend backend) throws Exception {
//Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't //Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't
// want to add to a namespace for some reason) // want to add to a namespace for some reason)
//Note that we ignore "*Bp", "*Gradient", "*Derivative" etc ops //Note that we ignore "*Bp", "*Gradient", "*Derivative" etc ops
@ -354,8 +358,11 @@ public class TestOpMapping extends BaseNd4jTest {
s.add(Assign.class); s.add(Assign.class);
} }
@Test @Disabled @Test
public void generateOpClassList() throws Exception{ @Disabled
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void generateOpClassList(Nd4jBackend backend) throws Exception{
Reflections reflections = new Reflections("org.nd4j"); Reflections reflections = new Reflections("org.nd4j");
Set<Class<? extends DifferentialFunction>> subTypes = reflections.getSubTypesOf(DifferentialFunction.class); Set<Class<? extends DifferentialFunction>> subTypes = reflections.getSubTypesOf(DifferentialFunction.class);
@ -366,12 +373,7 @@ public class TestOpMapping extends BaseNd4jTest {
l.add(c); l.add(c);
} }
Collections.sort(l, new Comparator<Class<?>>() { Collections.sort(l, Comparator.comparing(Class::getName));
@Override
public int compare(Class<?> o1, Class<?> o2) {
return o1.getName().compareTo(o2.getName());
}
});
for(Class<?> c : l){ for(Class<?> c : l){
System.out.println(c.getName() + ".class,"); System.out.println(c.getName() + ".class,");

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -31,7 +33,7 @@ import org.nd4j.autodiff.samediff.internal.FrameIter;
import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.memory.NoOpMemoryMgr; import org.nd4j.autodiff.samediff.internal.memory.NoOpMemoryMgr;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -46,19 +48,17 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class TestSessions extends BaseNd4jTest { public class TestSessions extends BaseNd4jTestWithBackends {
public TestSessions(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering() {
return 'c'; return 'c';
} }
@Test @Test
public void testInferenceSessionBasic(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInferenceSessionBasic(Nd4jBackend backend) {
//So far: trivial test to check execution order //So far: trivial test to check execution order
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -90,7 +90,9 @@ public class TestSessions extends BaseNd4jTest {
@Test @Test
public void testInferenceSessionBasic2(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInferenceSessionBasic2(Nd4jBackend backend) {
//So far: trivial test to check execution order //So far: trivial test to check execution order
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -126,7 +128,9 @@ public class TestSessions extends BaseNd4jTest {
} }
@Test @Test
public void testMergeSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeSimple(Nd4jBackend backend) {
//This isn't really a sensible graph, as merge op behaviour is undefined when multiple inputs are available... //This isn't really a sensible graph, as merge op behaviour is undefined when multiple inputs are available...
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -162,7 +166,9 @@ public class TestSessions extends BaseNd4jTest {
@Test @Test
public void testSwitchSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSwitchSimple(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 3,3); SDVariable x = sd.placeHolder("x", DataType.FLOAT, 3,3);

View File

@ -21,10 +21,12 @@
package org.nd4j.autodiff.internal; package org.nd4j.autodiff.internal;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.internal.DependencyList; import org.nd4j.autodiff.samediff.internal.DependencyList;
import org.nd4j.autodiff.samediff.internal.DependencyTracker; import org.nd4j.autodiff.samediff.internal.DependencyTracker;
import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker; import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -35,11 +37,8 @@ import java.util.Collections;
import static junit.framework.TestCase.assertNotNull; import static junit.framework.TestCase.assertNotNull;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class TestDependencyTracker extends BaseNd4jTest { public class TestDependencyTracker extends BaseNd4jTestWithBackends {
public TestDependencyTracker(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -47,7 +46,9 @@ public class TestDependencyTracker extends BaseNd4jTest {
} }
@Test @Test
public void testSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>(); DependencyTracker<String,String> dt = new DependencyTracker<>();
@ -94,7 +95,9 @@ public class TestDependencyTracker extends BaseNd4jTest {
} }
@Test @Test
public void testSatisfiedBeforeAdd(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSatisfiedBeforeAdd(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>(); DependencyTracker<String,String> dt = new DependencyTracker<>();
//Check different order of adding dependencies: i.e., mark X as satisfied, then add x -> y dependency //Check different order of adding dependencies: i.e., mark X as satisfied, then add x -> y dependency
@ -133,7 +136,9 @@ public class TestDependencyTracker extends BaseNd4jTest {
} }
@Test @Test
public void testMarkUnsatisfied(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMarkUnsatisfied(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>(); DependencyTracker<String,String> dt = new DependencyTracker<>();
dt.addDependency("y", "x"); dt.addDependency("y", "x");
@ -165,6 +170,8 @@ public class TestDependencyTracker extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIdentityDependencyTracker(){ public void testIdentityDependencyTracker(){
IdentityDependencyTracker<INDArray, String> dt = new IdentityDependencyTracker<>(); IdentityDependencyTracker<INDArray, String> dt = new IdentityDependencyTracker<>();
assertTrue(dt.isEmpty()); assertTrue(dt.isEmpty());

View File

@ -21,6 +21,8 @@
package org.nd4j.autodiff.opvalidation; package org.nd4j.autodiff.opvalidation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.GradCheckUtil; import org.nd4j.autodiff.validation.GradCheckUtil;
@ -38,12 +40,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class ActivationGradChecks extends BaseOpValidation { public class ActivationGradChecks extends BaseOpValidation {
public ActivationGradChecks(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testActivationGradientCheck1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testActivationGradientCheck1(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4)); SDVariable in = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4));
@ -61,7 +62,9 @@ public class ActivationGradChecks extends BaseOpValidation {
} }
@Test @Test
public void testActivationGradientCheck2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testActivationGradientCheck2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.DOUBLE, 3, 4); SDVariable x = sd.placeHolder("x", DataType.DOUBLE, 3, 4);

View File

@ -21,18 +21,14 @@
package org.nd4j.autodiff.opvalidation; package org.nd4j.autodiff.opvalidation;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
public abstract class BaseOpValidation extends BaseNd4jTest { public abstract class BaseOpValidation extends BaseNd4jTestWithBackends {
private DataType initialType; private DataType initialType = Nd4j.dataType();
public BaseOpValidation(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {

View File

@ -27,6 +27,8 @@ import java.util.List;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.OpValidation;
@ -65,9 +67,6 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class LayerOpValidation extends BaseOpValidation { public class LayerOpValidation extends BaseOpValidation {
public LayerOpValidation(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
@ -75,7 +74,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testXwPlusB() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testXwPlusB(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -109,7 +110,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReluLayer() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReluLayer(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -137,7 +140,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBiasAdd() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAdd(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -161,7 +166,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testConv2d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv2d(Nd4jBackend backend) {
//avg pool, batch norm, conv2d, max pool 2d, pooling2d, upsampling //avg pool, batch norm, conv2d, max pool 2d, pooling2d, upsampling
//Tested elsewhere: deconv2d, depthwise2d, LRN, sconv2d //Tested elsewhere: deconv2d, depthwise2d, LRN, sconv2d
@ -301,7 +308,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLrn2d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLrn2d(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}}; int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}};
@ -342,7 +351,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testIm2Col() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIm2Col(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873 //OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -381,7 +392,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testOutputShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOutputShape(Nd4jBackend backend) {
long[] inSize = {1, 8, 8, 3}; long[] inSize = {1, 8, 8, 3};
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -431,7 +444,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testAvgPool() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAvgPool(Nd4jBackend backend) {
long[] inSize = {1, 8, 8, 3}; //NHWC long[] inSize = {1, 8, 8, 3}; //NHWC
Pooling2DConfig conf = Pooling2DConfig.builder() Pooling2DConfig conf = Pooling2DConfig.builder()
@ -474,7 +489,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testConv3d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv3d(Nd4jBackend backend) {
//Pooling3d, Conv3D, batch norm //Pooling3d, Conv3D, batch norm
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -576,7 +593,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testDepthWiseConv2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepthWiseConv2dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int depthWise = 4; int depthWise = 4;
int kH = 2; int kH = 2;
@ -615,7 +634,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSeparableConv2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSeparableConv2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 2; int nIn = 2;
int nOut = 3; int nOut = 3;
@ -671,7 +692,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDeconv2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeconv2dBasic(Nd4jBackend backend) {
int nIn = 2; int nIn = 2;
int nOut = 3; int nOut = 3;
int kH = 2; int kH = 2;
@ -715,7 +738,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testConv2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv2dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
int kH = 2; int kH = 2;
@ -756,7 +781,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMaxPoolingArgMax() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxPoolingArgMax(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 3; int nIn = 3;
int kH = 2; int kH = 2;
@ -785,7 +812,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMaxPooling2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxPooling2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 3; int nIn = 3;
int kH = 2; int kH = 2;
@ -843,7 +872,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testAvgPooling2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAvgPooling2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 3; int nIn = 3;
int kH = 2; int kH = 2;
@ -892,7 +923,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testAvgPooling3dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAvgPooling3dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int kH = 2; int kH = 2;
int kW = 2; int kW = 2;
@ -929,7 +962,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMaxPooling3dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxPooling3dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int kH = 2; int kH = 2;
int kW = 2; int kW = 2;
@ -967,7 +1002,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testConv1dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
int k = 2; int k = 2;
@ -1002,7 +1039,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testConv1dCausal() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1dCausal(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
@ -1051,7 +1090,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testConv1dForward() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1dForward(Nd4jBackend backend) {
int nIn = 2; int nIn = 2;
int nOut = 1; int nOut = 1;
int kernel = 3; int kernel = 3;
@ -1094,7 +1135,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testConv3dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv3dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
int kH = 2; int kH = 2;
@ -1140,7 +1183,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDeConv3dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeConv3dBasic(Nd4jBackend backend) {
int nIn = 4; int nIn = 4;
int nOut = 3; int nOut = 3;
int kH = 2; int kH = 2;
@ -1185,7 +1230,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLayerNorm() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNorm(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1210,7 +1257,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLayerNorm4d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNorm4d(Nd4jBackend backend) {
int mb = 3; int mb = 3;
int ch = 4; int ch = 4;
for (boolean nchw : new boolean[]{true, false}) { for (boolean nchw : new boolean[]{true, false}) {
@ -1242,7 +1291,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testLayerNormOP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormOP(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1258,7 +1309,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLayerNormNoBias() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormNoBias(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1281,7 +1334,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLayerNormOPNoBias() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormOPNoBias(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1296,7 +1351,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLayerNormNoDeviation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormNoDeviation(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
random.putScalar(1, i, 7); random.putScalar(1, i, 7);
@ -1326,7 +1383,7 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test() @Test()
public void exceptionThrown_WhenConv1DConfigInvalid() { public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
@ -1355,7 +1412,7 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test() @Test()
public void exceptionThrown_WhenConv2DConfigInvalid() { public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1378,7 +1435,7 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test() @Test()
public void exceptionThrown_WhenConf3DInvalid() { public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1411,7 +1468,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLayerNormMixedOrders() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormMixedOrders(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f'); INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');
INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f'); INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f');
@ -1458,7 +1517,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBiasAdd_nchw_nhwc() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAdd_nchw_nhwc(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
for (boolean nchw : new boolean[]{true, false}) { for (boolean nchw : new boolean[]{true, false}) {
@ -1489,6 +1550,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepthwiseConv2D(){ public void testDepthwiseConv2D(){
int bS = 10; int bS = 10;
@ -1527,7 +1590,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void LSTMLayerTestCase1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void LSTMLayerTestCase1(Nd4jBackend backend) {
int bS = 5; int bS = 5;
int nIn = 3; int nIn = 3;
@ -1602,7 +1667,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void LSTMLayerTestCase2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void LSTMLayerTestCase2(Nd4jBackend backend) {
int bS = 5; int bS = 5;
int nIn = 3; int nIn = 3;
int numUnits = 7; int numUnits = 7;
@ -1660,7 +1727,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void LSTMLayerTestCase3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void LSTMLayerTestCase3(Nd4jBackend backend) {
int bS = 5; int bS = 5;
int nIn = 3; int nIn = 3;
int numUnits = 7; int numUnits = 7;
@ -1721,7 +1790,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void GRUTestCase() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void GRUTestCase(Nd4jBackend backend) {
int bS = 5; int bS = 5;
int nIn = 4; int nIn = 4;
int nOut = 6; int nOut = 6;

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -43,9 +45,7 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class LossOpValidation extends BaseOpValidation { public class LossOpValidation extends BaseOpValidation {
public LossOpValidation(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
@ -56,7 +56,9 @@ public class LossOpValidation extends BaseOpValidation {
public static final Set<String> NO_BP_YET = new HashSet<>(); public static final Set<String> NO_BP_YET = new HashSet<>();
@Test @Test
public void testLoss2d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoss2d(Nd4jBackend backend) {
final List<String> oneDimensionalOutputFns = Arrays.asList("cosine", "mpwse", "softmaxxent", "softmaxxent_smooth", "mpwse", "sparsesoftmax"); final List<String> oneDimensionalOutputFns = Arrays.asList("cosine", "mpwse", "softmaxxent", "softmaxxent_smooth", "mpwse", "sparsesoftmax");
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -368,6 +370,8 @@ public class LossOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCosineDistance(){ public void testCosineDistance(){
INDArray arr = Nd4j.create(new double[][]{{-0.3, -0.2, -0.1}, {0, 0.1, 0.2}}); INDArray arr = Nd4j.create(new double[][]{{-0.3, -0.2, -0.1}, {0, 0.1, 0.2}});
INDArray label = Nd4j.create(new double[][]{{1.0, 2.0, 3.0}, {-1.0, 2.0, 1.0}}); INDArray label = Nd4j.create(new double[][]{{1.0, 2.0, 3.0}, {-1.0, 2.0, 1.0}});
@ -386,6 +390,8 @@ public class LossOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testL2Loss(){ public void testL2Loss(){
for( int rank=0; rank<=3; rank++ ){ for( int rank=0; rank<=3; rank++ ){
@ -428,7 +434,9 @@ public class LossOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNonZeroResult() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNonZeroResult(Nd4jBackend backend) {
INDArray predictions = Nd4j.rand(DataType.DOUBLE, 10, 5); INDArray predictions = Nd4j.rand(DataType.DOUBLE, 10, 5);
INDArray w = Nd4j.scalar(1.0); INDArray w = Nd4j.scalar(1.0);
INDArray label = Nd4j.rand(DataType.DOUBLE, 10, 5); INDArray label = Nd4j.rand(DataType.DOUBLE, 10, 5);
@ -486,6 +494,8 @@ public class LossOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void TestStdLossMixedDataType(){ public void TestStdLossMixedDataType(){
// Default Data Type in this test suite is Double. // Default Data Type in this test suite is Double.
// This test used to throw an Exception that we have mixed data types. // This test used to throw an Exception that we have mixed data types.

View File

@ -23,6 +23,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -78,13 +80,12 @@ import static org.junit.Assume.assumeNotNull;
@Slf4j @Slf4j
public class MiscOpValidation extends BaseOpValidation { public class MiscOpValidation extends BaseOpValidation {
public MiscOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testGradientAutoBroadcast1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGradientAutoBroadcast1(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -171,7 +172,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testGradientAutoBroadcast2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGradientAutoBroadcast2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -260,7 +263,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testGradientAutoBroadcast3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGradientAutoBroadcast3(Nd4jBackend backend) {
//These tests: output size > input sizes //These tests: output size > input sizes
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -368,7 +373,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testScatterOpGradients() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScatterOpGradients(Nd4jBackend backend) {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
for (int i = 0; i < 7; i++) { for (int i = 0; i < 7; i++) {
@ -470,6 +477,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScatterUpdate(){ public void testScatterUpdate(){
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3); INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3);
INDArray updates = Nd4j.create(new float[][]{ INDArray updates = Nd4j.create(new float[][]{
@ -491,7 +500,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testGatherGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -542,6 +553,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrace(){ public void testTrace(){
//TODO need to work out how to handle shape_op for scalars... //TODO need to work out how to handle shape_op for scalars...
//OpValidationSuite.ignoreFailing(); //OpValidationSuite.ignoreFailing();
@ -567,7 +580,9 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
public void testTensorGradTensorMmul() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTensorGradTensorMmul(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -589,7 +604,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMulGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMulGradient(Nd4jBackend backend) {
INDArray arr1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray arr1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray arr2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray arr2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
@ -654,22 +671,21 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
public void testMmulGradientManual() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulGradientManual(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
Map<String, INDArray> inputs = new HashMap<>(); Map<String, INDArray> inputs = new HashMap<>();
inputs.put("x", sumInput); inputs.put("x", sumInput);
inputs.put("y", sumInput.dup()); inputs.put("y", sumInput.dup());
sameDiff.defineFunction("mmulGradient", new SameDiffFunctionDefinition() { sameDiff.defineFunction("mmulGradient", (sameDiff1, inputs1, variableInputs) -> {
@Override SDVariable input = sameDiff1.var("x", inputs1.get("x"));
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) { SDVariable input2 = sameDiff1.var("y", inputs1.get("y"));
SDVariable input = sameDiff.var("x", inputs.get("x")); SDVariable exp = sameDiff1.mmul(input, input2);
SDVariable input2 = sameDiff.var("y", inputs.get("y")); SDVariable sum = sameDiff1.sum(exp, Integer.MAX_VALUE);
SDVariable exp = sameDiff.mmul(input, input2);
SDVariable sum = sameDiff.sum(exp, Integer.MAX_VALUE);
return new SDVariable[]{sum}; return new SDVariable[]{sum};
}
}, inputs); }, inputs);
@ -698,6 +714,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulGradients(){ public void testMmulGradients(){
int[] aShape = new int[]{2,3}; int[] aShape = new int[]{2,3};
int[] bShape = new int[]{3,4}; int[] bShape = new int[]{3,4};
@ -749,7 +767,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBatchMmulBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBatchMmulBasic(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6873 OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6873
int M = 5; int M = 5;
int N = 3; int N = 3;
@ -774,7 +794,9 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
public void testMmulWithTranspose() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulWithTranspose(Nd4jBackend backend) {
//Here: [x,3]^T * [x,4] = [3,4] //Here: [x,3]^T * [x,4] = [3,4]
@ -811,6 +833,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulOutputSizeCalculation(){ public void testMmulOutputSizeCalculation(){
//[3,2] x [2,4] with result transpose: output shape [4,3] //[3,2] x [2,4] with result transpose: output shape [4,3]
INDArray a = Nd4j.create(3,2); INDArray a = Nd4j.create(3,2);
@ -843,6 +867,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFillOp(){ public void testFillOp(){
INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT); INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT);
@ -857,6 +883,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm(){ public void testClipByNorm(){
//Expected: if array.norm2(1) is less than 1.0, not modified //Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -889,6 +917,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm2(){ public void testClipByNorm2(){
//Expected: if array.norm2(1) is less than 1.0, not modified //Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -932,6 +962,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm1(){ public void testClipByNorm1(){
//Expected: if array.norm2(1) is less than 1.0, not modified //Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -972,6 +1004,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm0(){ public void testClipByNorm0(){
//Expected: if array.norm2(0) is less than 1.0, not modified //Expected: if array.norm2(0) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -1001,6 +1035,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCumSum(){ public void testCumSum(){
List<String> failing = new ArrayList<>(); List<String> failing = new ArrayList<>();
@ -1066,6 +1102,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCumProd(){ public void testCumProd(){
List<String> failing = new ArrayList<>(); List<String> failing = new ArrayList<>();
@ -1134,6 +1172,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot1(){ public void testOneHot1(){
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -1164,6 +1204,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHotOp(){ public void testOneHotOp(){
//https://www.tensorflow.org/api_docs/python/tf/one_hot //https://www.tensorflow.org/api_docs/python/tf/one_hot
//https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp //https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp
@ -1178,7 +1220,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testOneHot2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot2(Nd4jBackend backend) {
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1); INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
@ -1198,7 +1242,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testOneHot4() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot4(Nd4jBackend backend) {
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1); INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
@ -1218,7 +1264,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testOneHot3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot3(Nd4jBackend backend) {
//https://github.com/deeplearning4j/deeplearning4j/issues/6872 //https://github.com/deeplearning4j/deeplearning4j/issues/6872
//https://www.tensorflow.org/api_docs/python/tf/one_hot //https://www.tensorflow.org/api_docs/python/tf/one_hot
@ -1253,6 +1301,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLinspace(){ public void testLinspace(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable out = sd.linspace("linspace", DataType.DOUBLE, 1,10,10); SDVariable out = sd.linspace("linspace", DataType.DOUBLE, 1,10,10);
@ -1266,6 +1316,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLinspace2(){ public void testLinspace2(){
OpValidationSuite.ignoreFailing(); //TODO 2019/01/18 OpValidationSuite.ignoreFailing(); //TODO 2019/01/18
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1280,7 +1332,9 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
public void testShapeFn() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShapeFn(Nd4jBackend backend) {
INDArray in = Nd4j.create(new long[]{1, 2}); INDArray in = Nd4j.create(new long[]{1, 2});
@ -1294,7 +1348,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testShapeFn2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShapeFn2(Nd4jBackend backend) {
INDArray i = Nd4j.create(1,3); INDArray i = Nd4j.create(1,3);
@ -1307,6 +1363,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeRank1(){ public void testMergeRank1(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5)); SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5));
@ -1325,7 +1383,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDiagPart() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagPart(Nd4jBackend backend) {
INDArray i = Nd4j.create(5,5); INDArray i = Nd4j.create(5,5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1337,7 +1397,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDiagShapeFn() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagShapeFn(Nd4jBackend backend) {
INDArray i = Nd4j.create(5,5); INDArray i = Nd4j.create(5,5);
CustomOp op = new DiagPart(i, null); CustomOp op = new DiagPart(i, null);
@ -1350,6 +1412,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZerosOnesLike(){ public void testZerosOnesLike(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1392,6 +1456,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZerosLikeOp(){ public void testZerosLikeOp(){
INDArray arr = Nd4j.scalar(DataType.DOUBLE, 1.0); INDArray arr = Nd4j.scalar(DataType.DOUBLE, 1.0);
@ -1407,6 +1473,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConfusionMatrix(){ public void testConfusionMatrix(){
DataType dt = DataType.DOUBLE; DataType dt = DataType.DOUBLE;
@ -1443,6 +1511,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIsNonDecreasingIsStrictlyIncr(){ public void testIsNonDecreasingIsStrictlyIncr(){
List<long[]> shapes = Arrays.asList(null, new long[]{12}, new long[]{1,12}, new long[]{3,4}, new long[]{2,2,3}); List<long[]> shapes = Arrays.asList(null, new long[]{12}, new long[]{1,12}, new long[]{3,4}, new long[]{2,2,3});
@ -1506,6 +1576,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExtractImagePatches(){ public void testExtractImagePatches(){
/* /*
tf.reset_default_graph() tf.reset_default_graph()
@ -1553,6 +1625,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentProdBpSimple(){ public void testSegmentProdBpSimple(){
INDArray segmentIdxs = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT); INDArray segmentIdxs = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT);
@ -1573,6 +1647,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulRank4() throws Exception { public void testMmulRank4() throws Exception {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1608,6 +1684,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulRank4_simple(){ public void testMmulRank4_simple(){
INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64); INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
@ -1634,6 +1712,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNthElementRank1(){ public void testNthElementRank1(){
INDArray in = Nd4j.createFromArray(new double[]{0,1,2,3,4,5,6,7,8,9}); INDArray in = Nd4j.createFromArray(new double[]{0,1,2,3,4,5,6,7,8,9});
INDArray n = Nd4j.scalar(0); INDArray n = Nd4j.scalar(0);
@ -1656,6 +1736,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTensorMmulShape(){ public void testTensorMmulShape(){
INDArray a = Nd4j.create(new double[]{2}).reshape(1); INDArray a = Nd4j.create(new double[]{2}).reshape(1);
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2); INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
@ -1674,6 +1756,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTensorMmulShape2(){ public void testTensorMmulShape2(){
INDArray a = Nd4j.create(new double[]{2}).reshape(1); INDArray a = Nd4j.create(new double[]{2}).reshape(1);
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2); INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
@ -1682,6 +1766,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStopGradient(){ public void testStopGradient(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1701,6 +1787,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckNumerics(){ public void testCheckNumerics(){
OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/7927 OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/7927
@ -1744,7 +1832,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testCheckNumerics2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckNumerics2(Nd4jBackend backend) {
INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4); INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4);
INDArray msg = Nd4j.scalar("My error message!"); INDArray msg = Nd4j.scalar("My error message!");
@ -1757,6 +1847,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testHistogramFixedWidth(){ public void testHistogramFixedWidth(){
//Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf] //Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf]
INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9); INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9);
@ -1775,6 +1867,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicPartition(){ public void testDynamicPartition(){
INDArray data = Nd4j.createFromArray(2, 1, 2, 0); INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
@ -1793,6 +1887,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testListDiff(){ public void testListDiff(){
INDArray x = Nd4j.createFromArray(0, 1, 2, 3); INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
INDArray y = Nd4j.createFromArray(3, 1); INDArray y = Nd4j.createFromArray(3, 1);
@ -1812,7 +1908,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDivideNoNan() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDivideNoNan(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff() OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff()
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1836,7 +1934,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDigamma() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDigamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -1851,7 +1951,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testFlatten() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFlatten(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1873,7 +1975,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testFusedBatchNorm() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFusedBatchNorm(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1918,7 +2022,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testIgamma() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIgamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -1934,7 +2040,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testIgammaC() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIgammaC(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -1951,7 +2059,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLgamma() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLgamma(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1976,7 +2086,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLu() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLu(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2007,7 +2119,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMatrixBandPart() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixBandPart(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2037,7 +2151,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testPolygamma() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPolygamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -2053,7 +2169,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTriangularSolve() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTriangularSolve(Nd4jBackend backend) {
INDArray a = Nd4j.createFromArray(new float[]{ INDArray a = Nd4j.createFromArray(new float[]{
3.f, 0.f, 0.f, 0.f, 3.f, 0.f, 0.f, 0.f,
@ -2077,7 +2195,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBiasAdd() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAdd(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2106,7 +2226,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBiasAddGrad() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAddGrad(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2126,7 +2248,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testRoll() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRoll(Nd4jBackend backend) {
INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}). 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}).
@ -2146,6 +2270,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSeqMask(){ public void testSeqMask(){
INDArray arr = Nd4j.createFromArray(1,2,3); INDArray arr = Nd4j.createFromArray(1,2,3);
INDArray maxLen = Nd4j.scalar(4); INDArray maxLen = Nd4j.scalar(4);

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -51,12 +53,11 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class RandomOpValidation extends BaseOpValidation { public class RandomOpValidation extends BaseOpValidation {
public RandomOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testRandomOpsSDVarShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomOpsSDVarShape(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -157,7 +158,9 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testRandomOpsLongShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomOpsLongShape(Nd4jBackend backend) {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
for (long[] shape : Arrays.asList(new long[]{1000}, new long[]{100, 10}, new long[]{40, 5, 5})) { for (long[] shape : Arrays.asList(new long[]{1000}, new long[]{100, 10}, new long[]{40, 5, 5})) {
@ -283,6 +286,8 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomBinomial(){ public void testRandomBinomial(){
INDArray z = Nd4j.create(new long[]{10}); INDArray z = Nd4j.create(new long[]{10});
@ -293,7 +298,9 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testUniformRankSimple() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUniformRankSimple(Nd4jBackend backend) {
INDArray arr = Nd4j.createFromArray(new double[]{100.0}); INDArray arr = Nd4j.createFromArray(new double[]{100.0});
// OpTestCase tc = new OpTestCase(DynamicCustomOp.builder("randomuniform") // OpTestCase tc = new OpTestCase(DynamicCustomOp.builder("randomuniform")
@ -325,7 +332,9 @@ public class RandomOpValidation extends BaseOpValidation {
@Test @Test
public void testRandomExponential() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomExponential(Nd4jBackend backend) {
long length = 1_000_000; long length = 1_000_000;
INDArray shape = Nd4j.createFromArray(new double[]{length}); INDArray shape = Nd4j.createFromArray(new double[]{length});
INDArray out = Nd4j.createUninitialized(new long[]{length}); INDArray out = Nd4j.createUninitialized(new long[]{length});
@ -347,6 +356,8 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRange(){ public void testRange(){
//Technically deterministic, not random... //Technically deterministic, not random...
@ -380,6 +391,8 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllEmptyReduce(){ public void testAllEmptyReduce(){
INDArray x = Nd4j.createFromArray(true, true, true); INDArray x = Nd4j.createFromArray(true, true, true);
All all = new All(x); All all = new All(x);
@ -389,6 +402,8 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUniformDtype(){ public void testUniformDtype(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){ for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
@ -417,6 +432,8 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomExponential2(){ public void testRandomExponential2(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
DynamicCustomOp op = DynamicCustomOp.builder("random_exponential") DynamicCustomOp op = DynamicCustomOp.builder("random_exponential")

View File

@ -24,6 +24,8 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.validation.OpTestCase; import org.nd4j.autodiff.validation.OpTestCase;
import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -51,10 +53,6 @@ public class ReductionBpOpValidation extends BaseOpValidation {
private DataType initialType; private DataType initialType;
public ReductionBpOpValidation(Nd4jBackend backend) {
super(backend);
}
@BeforeEach @BeforeEach
public void before() { public void before() {
Nd4j.create(1); Nd4j.create(1);
@ -71,14 +69,16 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@AfterEach @AfterEach
public void tearDown() { public void tearDown(Nd4jBackend backend) {
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
} }
@Test @Test
public void testReduceSumBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduceSumBP(Nd4jBackend backend) {
//Full array reduction //Full array reduction
//reduce_sum_bp op: has 2 inputs (original pre-reduce input, and gradient at output (epsilon)) //reduce_sum_bp op: has 2 inputs (original pre-reduce input, and gradient at output (epsilon))
@ -104,7 +104,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReduceSumAlongDim0BP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduceSumAlongDim0BP(Nd4jBackend backend) {
//Reduction along dimension //Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar //Inputs/outputs as before - but note that the output is no longer a scalar
@ -130,7 +132,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReduceSumAlongDim1BP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduceSumAlongDim1BP(Nd4jBackend backend) {
//Reduction along dimension //Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar //Inputs/outputs as before - but note that the output is no longer a scalar
@ -158,7 +162,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test @Test
public void testMeanBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanBP(Nd4jBackend backend) {
//dL/dIn_i = dL/dOut * dOut/dIn_i = dL/dOut * (1/N * sum_j (in_j)) //dL/dIn_i = dL/dOut * dOut/dIn_i = dL/dOut * (1/N * sum_j (in_j))
// = 1/N * dL/dOut // = 1/N * dL/dOut
@ -189,7 +195,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMeanBP_Rank1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanBP_Rank1(Nd4jBackend backend) {
INDArray dLdOut = Nd4j.scalar(0.5); INDArray dLdOut = Nd4j.scalar(0.5);
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3}); INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5 / 3); INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5 / 3);
@ -202,7 +210,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMeanAlongDim0BP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanAlongDim0BP(Nd4jBackend backend) {
//Reduction along dimension //Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar //Inputs/outputs as before - but note that the output is no longer a scalar
@ -230,7 +240,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMeanAlongDim1BP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanAlongDim1BP(Nd4jBackend backend) {
//Reduction along dimension //Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar //Inputs/outputs as before - but note that the output is no longer a scalar
@ -258,7 +270,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test @Test
public void testMinBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMinBP(Nd4jBackend backend) {
//Full array min reduction //Full array min reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i //dL/dIn_i = dL/dOut * dOut/dIn_i
@ -297,7 +311,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMinAlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMinAlongDimensionBP(Nd4jBackend backend) {
//Full array min reduction //Full array min reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i //dL/dIn_i = dL/dOut * dOut/dIn_i
@ -340,7 +356,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMaxBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxBP(Nd4jBackend backend) {
//Full array max reduction //Full array max reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i //dL/dIn_i = dL/dOut * dOut/dIn_i
@ -370,7 +388,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMaxAlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxAlongDimensionBP(Nd4jBackend backend) {
//Full array min reduction //Full array min reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i //dL/dIn_i = dL/dOut * dOut/dIn_i
@ -413,7 +433,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testProdBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testProdBP(Nd4jBackend backend) {
//Full array product reduction //Full array product reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i //dL/dIn_i = dL/dOut * dOut/dIn_i
@ -442,7 +464,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testProdAlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testProdAlongDimensionBP(Nd4jBackend backend) {
//dL/dIn_i = dL/dOut * dOut/dIn_i //dL/dIn_i = dL/dOut * dOut/dIn_i
// = dL/dOut * d(prod(in))/dIn_i // = dL/dOut * d(prod(in))/dIn_i
// = dL/dOut * (prod(in) / in_i) // = dL/dOut * (prod(in) / in_i)
@ -498,7 +522,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStdevBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdevBP(Nd4jBackend backend) {
//If out = stdev(in) then: //If out = stdev(in) then:
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = (in_i-mean)/(stdev * (n-1)) //dOut/dIn_i = (in_i-mean)/(stdev * (n-1))
@ -534,7 +560,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStdevBP_Rank1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdevBP_Rank1(Nd4jBackend backend) {
INDArray dLdOut = Nd4j.scalar(0.5); INDArray dLdOut = Nd4j.scalar(0.5);
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3}); INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
double stdev = preReduceInput.stdNumber(true).doubleValue(); double stdev = preReduceInput.stdNumber(true).doubleValue();
@ -555,7 +583,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStdevAlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdevAlongDimensionBP(Nd4jBackend backend) {
//If out = stdev(in) then: //If out = stdev(in) then:
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = (in_i-mean)/(stdev * (n-1)) //dOut/dIn_i = (in_i-mean)/(stdev * (n-1))
@ -600,7 +630,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testVarianceBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVarianceBP(Nd4jBackend backend) {
//If out = variance(in) then: //If out = variance(in) then:
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = 2*(in_i-mean)/(n-1) //dOut/dIn_i = 2*(in_i-mean)/(n-1)
@ -636,7 +668,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testVarianceAlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVarianceAlongDimensionBP(Nd4jBackend backend) {
//If out = variance(in) then: //If out = variance(in) then:
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = 2*(in_i-mean)/(n-1) //dOut/dIn_i = 2*(in_i-mean)/(n-1)
@ -678,7 +712,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test @Test
public void testCumSumBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCumSumBP(Nd4jBackend backend) {
//Standard case, non-reverse, non-exclusive //Standard case, non-reverse, non-exclusive
//dL/dIn_i = sum_j dL/dOut_j * dOut_j/dIn_i //dL/dIn_i = sum_j dL/dOut_j * dOut_j/dIn_i
// = sum_j dL/dOut_j * d(in_0 + ... + in_j)/dIn_i // = sum_j dL/dOut_j * d(in_0 + ... + in_j)/dIn_i
@ -748,7 +784,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test @Test
public void testNorm2Bp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm2Bp(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * x/|x|_2 // = dL/dOut * x/|x|_2
@ -775,7 +813,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNorm2AlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm2AlongDimensionBP(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * x/|x|_2 // = dL/dOut * x/|x|_2
@ -808,7 +848,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNorm1Bp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm1Bp(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * sgn(in) // = dL/dOut * sgn(in)
@ -835,7 +877,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNorm1AlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm1AlongDimensionBP(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * sgn(in) // = dL/dOut * sgn(in)
@ -867,7 +911,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNormMaxBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormMaxBp(Nd4jBackend backend) {
//out = max_i (|in_i|) //out = max_i (|in_i|)
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise) // = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise)
@ -897,7 +943,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNormMaxAlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormMaxAlongDimensionBP(Nd4jBackend backend) {
//out = max_i (|in_i|) //out = max_i (|in_i|)
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise) // = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise)

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -76,16 +77,13 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class ReductionOpValidation extends BaseOpValidation { public class ReductionOpValidation extends BaseOpValidation {
public ReductionOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testStdev() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdev(Nd4jBackend backend) {
List<String> errors = new ArrayList<>(); List<String> errors = new ArrayList<>();
for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345, DataType.DOUBLE)) { for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345, DataType.DOUBLE)) {
@ -111,7 +109,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testZeroCount() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZeroCount(Nd4jBackend backend) {
List<String> allFailed = new ArrayList<>(); List<String> allFailed = new ArrayList<>();
for (int i = 0; i < 21; i++) { for (int i = 0; i < 21; i++) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -145,7 +145,9 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test @Test
public void testZeroFraction() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZeroFraction(Nd4jBackend backend) {
List<String> allFailed = new ArrayList<>(); List<String> allFailed = new ArrayList<>();
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -175,7 +177,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReductionGradientsSimple() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionGradientsSimple(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES //OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES
//Test reductions: final and only function //Test reductions: final and only function
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -344,7 +348,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReductionGradients1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionGradients1(Nd4jBackend backend) {
//Test reductions: final, but *not* the only function //Test reductions: final, but *not* the only function
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -472,7 +478,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReductionGradients2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionGradients2(Nd4jBackend backend) {
//Test reductions: NON-final function //Test reductions: NON-final function
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -650,7 +658,9 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test @Test
public void testReduce3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduce3(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int d0 = 3; int d0 = 3;
@ -755,7 +765,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMoments() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMoments(Nd4jBackend backend) {
for (int[] axes : new int[][]{{0}, {1}, {0, 1}}) { for (int[] axes : new int[][]{{0}, {1}, {0, 1}}) {
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -787,7 +799,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMomentsOp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMomentsOp(Nd4jBackend backend) {
int[] axes = new int[]{0}; int[] axes = new int[]{0};
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -804,7 +818,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNormalizeMomentsOp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormalizeMomentsOp(Nd4jBackend backend) {
INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10); INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10);
INDArray ssSum = data.sum(0); INDArray ssSum = data.sum(0);
INDArray ssSqSum = data.mul(data).sum(0); INDArray ssSqSum = data.mul(data).sum(0);
@ -824,7 +840,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testAllAny() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllAny(Nd4jBackend backend) {
INDArray allZeros = Nd4j.zeros(DataType.FLOAT, 3, 4); INDArray allZeros = Nd4j.zeros(DataType.FLOAT, 3, 4);
INDArray allOnes = Nd4j.ones(DataType.FLOAT, 3, 4); INDArray allOnes = Nd4j.ones(DataType.FLOAT, 3, 4);
@ -852,7 +870,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testIndexAccum() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexAccum(Nd4jBackend backend) {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/); List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/);
@ -941,7 +961,9 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test @Test
public void testReduce3_2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduce3_2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int d0 = 3; int d0 = 3;
@ -1039,7 +1061,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReductionsBackwards() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionsBackwards(Nd4jBackend backend) {
// for (int i = 0; i < 7; i++) { // for (int i = 0; i < 7; i++) {
int i=5; int i=5;
{ {
@ -1108,6 +1132,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttention(){ public void testDotProductAttention(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3}); final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1133,6 +1159,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionWithMask(){ public void testDotProductAttentionWithMask(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3}); final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1163,6 +1191,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionMultiHeadInputWithMask(){ public void testDotProductAttentionMultiHeadInputWithMask(){
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
@ -1194,6 +1224,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionMultiHeadInput(){ public void testDotProductAttentionMultiHeadInput(){
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
@ -1221,6 +1253,8 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMultiHeadedDotProductAttention(){ public void testMultiHeadedDotProductAttention(){
final INDArray k = Nd4j.rand(new int[]{10, 4, 5}); final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
final INDArray v = Nd4j.rand(new int[]{10, 4, 5}); final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
@ -1272,6 +1306,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionWeirdInputs(){ public void testDotProductAttentionWeirdInputs(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3}); final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1309,6 +1345,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMultiHeadedDotProductAttentionWeirdInputs(){ public void testMultiHeadedDotProductAttentionWeirdInputs(){
final INDArray k = Nd4j.rand(new int[]{10, 4, 5}); final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
final INDArray v = Nd4j.rand(new int[]{10, 4, 5}); final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
@ -1366,7 +1404,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
} }
@Test @Test
public void testSufficientStatisticsOp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSufficientStatisticsOp(Nd4jBackend backend) {
INDArray data = Nd4j.createFromArray(new double[]{ INDArray data = Nd4j.createFromArray(new double[]{
5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1., 5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1.,
1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5 1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5
@ -1392,7 +1432,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStandardDeviation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardDeviation(Nd4jBackend backend) {
for (boolean keepDims : new boolean[]{false, true}) { for (boolean keepDims : new boolean[]{false, true}) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1419,7 +1461,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSquaredNorm() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSquaredNorm(Nd4jBackend backend) {
for (boolean keepDims : new boolean[]{false, true}) { for (boolean keepDims : new boolean[]{false, true}) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1442,7 +1486,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testShannonEntropy() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShannonEntropy(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695 OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1462,7 +1508,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testEntropy() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEntropy(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1481,7 +1529,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testAMean() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAMean(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1502,7 +1552,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMean() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMean(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1523,7 +1575,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNorm1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm1(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1544,7 +1598,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNorm2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1565,7 +1621,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNormMax() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormMax(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1586,7 +1644,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSoftmaxCrossEntropyWithLogitsLoss() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSoftmaxCrossEntropyWithLogitsLoss(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -43,12 +45,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j @Slf4j
public class RnnOpValidation extends BaseOpValidation { public class RnnOpValidation extends BaseOpValidation {
public RnnOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testRnnBlockCell(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRnnBlockCell(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int mb = 2; int mb = 2;
int nIn = 3; int nIn = 3;
@ -147,7 +148,9 @@ public class RnnOpValidation extends BaseOpValidation {
@Test @Test
public void testRnnBlockCellManualTFCompare() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRnnBlockCellManualTFCompare(Nd4jBackend backend) {
//Test case: "rnn/lstmblockcell/static_batch1_n3-2_tsLength1_noPH_noClip_fBias1_noIS" //Test case: "rnn/lstmblockcell/static_batch1_n3-2_tsLength1_noPH_noClip_fBias1_noIS"
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -209,6 +212,8 @@ public class RnnOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGRUCell(){ public void testGRUCell(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int mb = 2; int mb = 2;

View File

@ -28,6 +28,8 @@ import lombok.val;
import org.apache.commons.math3.linear.LUDecomposition; import org.apache.commons.math3.linear.LUDecomposition;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -67,9 +69,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
@Slf4j @Slf4j
public class ShapeOpValidation extends BaseOpValidation { public class ShapeOpValidation extends BaseOpValidation {
public ShapeOpValidation(Nd4jBackend backend) {
super(backend);
}
/* /*
To test: To test:
@ -83,7 +82,9 @@ public class ShapeOpValidation extends BaseOpValidation {
*/ */
@Test @Test
public void testConcat() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat(Nd4jBackend backend) {
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2}; // int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
int[] concatDim = new int[]{0, 0, 0}; int[] concatDim = new int[]{0, 0, 0};
List<List<int[]>> origShapes = new ArrayList<>(); List<List<int[]>> origShapes = new ArrayList<>();
@ -123,7 +124,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReshapeGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReshapeGradient(Nd4jBackend backend) {
//https://github.com/deeplearning4j/deeplearning4j/issues/6873 //https://github.com/deeplearning4j/deeplearning4j/issues/6873
int[] origShape = new int[]{3, 4, 5}; int[] origShape = new int[]{3, 4, 5};
@ -159,7 +162,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testPermuteGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermuteGradient(Nd4jBackend backend) {
int[] origShape = new int[]{3, 4, 5}; int[] origShape = new int[]{3, 4, 5};
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -197,6 +202,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRank(){ public void testRank(){
List<long[]> inShape = Arrays.asList(null, new long[]{1}, new long[]{6}, new long[]{3,4}, new long[]{3,4,5}); List<long[]> inShape = Arrays.asList(null, new long[]{1}, new long[]{6}, new long[]{3,4}, new long[]{3,4,5});
@ -224,7 +231,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testExpandDimsGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExpandDimsGradient(Nd4jBackend backend) {
val origShape = new long[]{3, 4}; val origShape = new long[]{3, 4};
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -280,7 +289,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSqueezeGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSqueezeGradient(Nd4jBackend backend) {
val origShape = new long[]{3, 4, 5}; val origShape = new long[]{3, 4, 5};
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -344,7 +355,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testSliceGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSliceGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
//Order here: original shape, begin, size //Order here: original shape, begin, size
@ -434,7 +447,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStridedSliceGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
//Order here: original shape, begin, size //Order here: original shape, begin, size
@ -497,7 +512,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testMerge() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMerge(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -573,7 +590,7 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test() @Test()
public void testStack() { public void testStack(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -664,7 +681,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testUnStack() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnStack(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -752,7 +771,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTile() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTile(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<int[]> tileArg = Arrays.asList( List<int[]> tileArg = Arrays.asList(
@ -824,6 +845,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTileBp(){ public void testTileBp(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -857,6 +880,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTileBp2(){ public void testTileBp2(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -891,7 +916,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testReshape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReshape(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(-5, 6, 12)).reshape(3, 4); INDArray arr = Transforms.sigmoid(Nd4j.linspace(-5, 6, 12)).reshape(3, 4);
SDVariable x = sameDiff.var("x", arr); SDVariable x = sameDiff.var("x", arr);
@ -907,7 +934,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReshape2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReshape2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int[] origShape = new int[]{3, 4, 5}; int[] origShape = new int[]{3, 4, 5};
@ -930,7 +959,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testTranspose() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTranspose(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
SDVariable x = sameDiff.var("x", arr); SDVariable x = sameDiff.var("x", arr);
@ -942,6 +973,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTransposeOp(){ public void testTransposeOp(){
INDArray arr = Nd4j.linspace(1,15, 15).reshape(5,3); INDArray arr = Nd4j.linspace(1,15, 15).reshape(5,3);
@ -955,7 +988,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShape(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
val shape = new long[]{2, 3}; val shape = new long[]{2, 3};
SDVariable x = sameDiff.var("x", shape); SDVariable x = sameDiff.var("x", shape);
@ -970,7 +1005,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSize() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSize(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
val shape = new long[]{2, 3}; val shape = new long[]{2, 3};
SDVariable x = sameDiff.var("x", DataType.FLOAT, shape); SDVariable x = sameDiff.var("x", DataType.FLOAT, shape);
@ -984,7 +1021,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDiagShapeFn() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagShapeFn(Nd4jBackend backend) {
INDArray i = Nd4j.linspace(1, 16, 16).reshape(4,4); INDArray i = Nd4j.linspace(1, 16, 16).reshape(4,4);
OpTestCase op = new OpTestCase(new DiagPart(i, null)); OpTestCase op = new OpTestCase(new DiagPart(i, null));
@ -998,6 +1037,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute(){ public void testPermute(){
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5); INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
INDArray exp = in.permute(0,1,2); //No op INDArray exp = in.permute(0,1,2); //No op
@ -1012,6 +1053,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute2(){ public void testPermute2(){
for (int[] perm : new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}) { for (int[] perm : new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}) {
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5); INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
@ -1032,6 +1075,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConstant(){ public void testConstant(){
//OpValidationSuite.ignoreFailing(); //OpValidationSuite.ignoreFailing();
@ -1059,6 +1104,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnstackEdgeCase2(){ public void testUnstackEdgeCase2(){
for( int i=0; i<3; i++ ) { for( int i=0; i<3; i++ ) {
@ -1073,7 +1120,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void invertPermutation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void invertPermutation(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new float[] {3, 4, 0, 2, 1}).castTo(DataType.INT); INDArray ia = Nd4j.create(new float[] {3, 4, 0, 2, 1}).castTo(DataType.INT);
@ -1090,6 +1139,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherNd(){ public void testGatherNd(){
List<INDArray> indices = new ArrayList<>(); List<INDArray> indices = new ArrayList<>();
@ -1128,7 +1179,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testReverseSequence() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReverseSequence(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
float[] input_data = new float[]{ float[] input_data = new float[]{
1, 2, 3, 1, 2, 3,
@ -1174,6 +1227,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixDeterminant(){ public void testMatrixDeterminant(){
OpValidationSuite.ignoreFailing(); //Gradient check failing OpValidationSuite.ignoreFailing(); //Gradient check failing
@ -1195,6 +1250,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeterminant22(){ public void testDeterminant22(){
OpValidationSuite.ignoreFailing(); //Gradient check failing OpValidationSuite.ignoreFailing(); //Gradient check failing
@ -1219,6 +1276,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixDeterminant3(){ public void testMatrixDeterminant3(){
OpValidationSuite.ignoreFailing(); //Gradient checks failing OpValidationSuite.ignoreFailing(); //Gradient checks failing
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1250,6 +1309,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixDeterminant4(){ public void testMatrixDeterminant4(){
OpValidationSuite.ignoreFailing(); //Gradient checks failing OpValidationSuite.ignoreFailing(); //Gradient checks failing
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1270,6 +1331,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentOps(){ public void testSegmentOps(){
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
//https://github.com/deeplearning4j/deeplearning4j/issues/6952 //https://github.com/deeplearning4j/deeplearning4j/issues/6952
@ -1362,6 +1425,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentMean(){ public void testSegmentMean(){
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3); INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3);
INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2); INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2);
@ -1382,7 +1447,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSequenceMask() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSequenceMask(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2}); INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2});
// arr is not trainable, so it's constant in model // arr is not trainable, so it's constant in model
@ -1416,6 +1483,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeshGrid(){ public void testMeshGrid(){
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -1472,6 +1541,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGather(){ public void testGather(){
List<INDArray> inArrs = new ArrayList<>(); List<INDArray> inArrs = new ArrayList<>();
List<Integer> axis = new ArrayList<>(); List<Integer> axis = new ArrayList<>();
@ -1541,7 +1612,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testGatherSimple() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherSimple(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Nd4j.create(new float[]{1, 2, 3, 4}, new long[]{2, 2}); INDArray arr = Nd4j.create(new float[]{1, 2, 3, 4}, new long[]{2, 2});
SDVariable x = sameDiff.var("x", arr); SDVariable x = sameDiff.var("x", arr);
@ -1551,7 +1624,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testGatherNdSingle() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherNdSingle(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(DataType.DOUBLE, 1, 24, 24)).reshape(2, 3, 4); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(DataType.DOUBLE, 1, 24, 24)).reshape(2, 3, 4);
INDArray arr2 = Nd4j.create(new float[]{1, 2, 3, 0, 1, 3, 1, 0, 2}, new long[]{3, 3}).castTo(DataType.INT); INDArray arr2 = Nd4j.create(new float[]{1, 2, 3, 0, 1, 3, 1, 0, 2}, new long[]{3, 3}).castTo(DataType.INT);
@ -1570,7 +1645,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStack2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStack2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2); INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2);
@ -1581,7 +1658,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testParallelStack() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testParallelStack(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2); INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2);
@ -1593,7 +1672,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testUnStack2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnStack2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Nd4j.zeros(3, 2); INDArray arr1 = Nd4j.zeros(3, 2);
INDArray arr2 = Nd4j.ones(3, 2); INDArray arr2 = Nd4j.ones(3, 2);
@ -1606,7 +1687,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testPermuteSimple() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermuteSimple(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3)); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3));
SDVariable x = sameDiff.var("x", arr); SDVariable x = sameDiff.var("x", arr);
@ -1617,7 +1700,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testConcat2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(4, 8, 4)).reshape(1,4); INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(4, 8, 4)).reshape(1,4);
@ -1628,7 +1713,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTile2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTile2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1,4)); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1,4));
SDVariable x = sameDiff.var("x", arr); SDVariable x = sameDiff.var("x", arr);
@ -1641,7 +1728,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testSlice2d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlice2d(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1657,7 +1746,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testSlice3d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlice3d(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1672,7 +1763,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStridedSlice2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSlice2dBasic(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1690,7 +1783,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testStridedSliceBeginEndMask() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceBeginEndMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1705,7 +1800,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStridedSliceEllipsisMask() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceEllipsisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
@ -1722,7 +1819,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStridedSliceNewAxisMask() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceNewAxisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
@ -1735,7 +1834,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStridedSliceNewAxisMask2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceNewAxisMask2(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
@ -1746,7 +1847,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStridedSliceShrinkAxisMask() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceShrinkAxisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1763,7 +1866,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSizeAt_1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSizeAt_1(Nd4jBackend backend) {
val array = Nd4j.create(10, 20, 30); val array = Nd4j.create(10, 20, 30);
val exp = Nd4j.scalar(DataType.LONG, 20); val exp = Nd4j.scalar(DataType.LONG, 20);
@ -1777,6 +1882,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEye(){ public void testEye(){
int[] rows = new int[]{3,3,3,3}; int[] rows = new int[]{3,3,3,3};
int[] cols = new int[]{3,2,2,2}; int[] cols = new int[]{3,2,2,2};
@ -1815,6 +1922,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSplit1(){ public void testSplit1(){
INDArray in = Nd4j.linspace(1,10,10).reshape(10); INDArray in = Nd4j.linspace(1,10,10).reshape(10);
INDArray axis = Nd4j.scalar(-1); INDArray axis = Nd4j.scalar(-1);
@ -1833,6 +1942,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSplit2(){ public void testSplit2(){
INDArray in = Nd4j.linspace(1,24,24).reshape(3,8); INDArray in = Nd4j.linspace(1,24,24).reshape(3,8);
INDArray axis = Nd4j.scalar(-1); INDArray axis = Nd4j.scalar(-1);
@ -1851,6 +1962,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDistancesExec(){ public void testDistancesExec(){
//https://github.com/deeplearning4j/deeplearning4j/issues/7001 //https://github.com/deeplearning4j/deeplearning4j/issues/7001
for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) { for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) {
@ -1906,6 +2019,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionShape(){ public void testReductionShape(){
INDArray shape = Nd4j.createFromArray(4,2); INDArray shape = Nd4j.createFromArray(4,2);
@ -1924,6 +2039,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void gatherTest(){ public void gatherTest(){
INDArray in = Nd4j.createFromArray(new double[][]{ INDArray in = Nd4j.createFromArray(new double[][]{
{1,2,3,4,5}, {1,2,3,4,5},
@ -1943,6 +2060,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSliceShape(){ public void testSliceShape(){
INDArray arr = Nd4j.arange(0, 25).reshape(1,5,5).castTo(DataType.INT); INDArray arr = Nd4j.arange(0, 25).reshape(1,5,5).castTo(DataType.INT);
@ -1964,6 +2083,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testWhereAllFalse(){ public void testWhereAllFalse(){
INDArray in = Nd4j.create(DataType.BOOL, 1917); INDArray in = Nd4j.create(DataType.BOOL, 1917);
DynamicCustomOp op = DynamicCustomOp.builder("Where") DynamicCustomOp op = DynamicCustomOp.builder("Where")
@ -1978,6 +2099,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherScalar(){ public void testGatherScalar(){
INDArray in = Nd4j.linspace(100, 200, 100, DataType.FLOAT).reshape(100); INDArray in = Nd4j.linspace(100, 200, 100, DataType.FLOAT).reshape(100);
INDArray indices = Nd4j.scalar(0); INDArray indices = Nd4j.scalar(0);
@ -2002,6 +2125,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCastEmpty(){ public void testCastEmpty(){
INDArray emptyLong = Nd4j.empty(DataType.LONG); INDArray emptyLong = Nd4j.empty(DataType.LONG);
int dtype = 9; //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h int dtype = 9; //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h
@ -2018,6 +2143,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherEmpty(){ public void testGatherEmpty(){
/* /*
tf.reset_default_graph() tf.reset_default_graph()
@ -2050,6 +2177,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSplitEmpty(){ public void testSplitEmpty(){
/* /*
tf.reset_default_graph() tf.reset_default_graph()
@ -2087,6 +2216,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcatEmpty(){ public void testConcatEmpty(){
/* /*
TF behaviour with concatenatioun of empty arrays: TF behaviour with concatenatioun of empty arrays:
@ -2136,6 +2267,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcatEmpty2(){ public void testConcatEmpty2(){
INDArray empty10a = Nd4j.create(DataType.INT, 1, 0); INDArray empty10a = Nd4j.create(DataType.INT, 1, 0);
INDArray empty10b = Nd4j.create(DataType.INT, 1, 0); INDArray empty10b = Nd4j.create(DataType.INT, 1, 0);
@ -2168,6 +2301,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyGather(){ public void testEmptyGather(){
/* /*
tf.reset_default_graph() tf.reset_default_graph()
@ -2200,6 +2335,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastDynamicShape1(){ public void testBroadcastDynamicShape1(){
//Test case: [2,1] and [4]: expect [2,4] //Test case: [2,1] and [4]: expect [2,4]
@ -2221,6 +2358,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastDynamicShape2(){ public void testBroadcastDynamicShape2(){
//Test case: [2,1,4] and [2,2,4]: expect [2,2,4] //Test case: [2,1,4] and [2,2,4]: expect [2,2,4]
@ -2243,6 +2382,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceShrinkAxis(){ public void testStridedSliceShrinkAxis(){
INDArray in = Nd4j.create(DataType.DOUBLE, 3,2,2); INDArray in = Nd4j.create(DataType.DOUBLE, 3,2,2);
INDArray begin = Nd4j.createFromArray(2); INDArray begin = Nd4j.createFromArray(2);
@ -2268,6 +2409,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceEmpty(){ public void testStridedSliceEmpty(){
INDArray in = Nd4j.createFromArray(10); //Integer, Length 1, rank 1, value 10 - Not used due to begin mask! INDArray in = Nd4j.createFromArray(10); //Integer, Length 1, rank 1, value 10 - Not used due to begin mask!
@ -2290,6 +2433,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceEdgeCase(){ public void testStridedSliceEdgeCase(){
INDArray in = Nd4j.scalar(10).reshape(1); //Int [1] INDArray in = Nd4j.scalar(10).reshape(1); //Int [1]
INDArray begin = Nd4j.ones(DataType.INT, 1); INDArray begin = Nd4j.ones(DataType.INT, 1);
@ -2315,6 +2460,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptySlice1(){ public void testEmptySlice1(){
INDArray in = Nd4j.createFromArray(38); INDArray in = Nd4j.createFromArray(38);
INDArray begin = Nd4j.createFromArray(1); INDArray begin = Nd4j.createFromArray(1);
@ -2334,6 +2481,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptySlice2(){ public void testEmptySlice2(){
INDArray in = Nd4j.createFromArray(38); INDArray in = Nd4j.createFromArray(38);
INDArray begin = Nd4j.createFromArray(0); INDArray begin = Nd4j.createFromArray(0);
@ -2353,6 +2502,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFill(){ public void testFill(){
INDArray shape = Nd4j.createFromArray(0,4); INDArray shape = Nd4j.createFromArray(0,4);
@ -2372,6 +2523,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFill2(){ public void testFill2(){
INDArray shape = Nd4j.createFromArray(0,4); INDArray shape = Nd4j.createFromArray(0,4);
@ -2389,6 +2542,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermuteShapeDynamicAxis(){ public void testPermuteShapeDynamicAxis(){
DynamicCustomOp op = DynamicCustomOp.builder("permute") DynamicCustomOp op = DynamicCustomOp.builder("permute")
@ -2418,6 +2573,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGather2(){ public void testGather2(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable input = sd.var("in", Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3)); SDVariable input = sd.var("in", Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3));
@ -2437,6 +2594,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute3(){ public void testPermute3(){
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2); INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
INDArray permute = Nd4j.createFromArray(1,0); INDArray permute = Nd4j.createFromArray(1,0);
@ -2455,6 +2614,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute4(){ public void testPermute4(){
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2); INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
INDArray permute = Nd4j.createFromArray(1,0); INDArray permute = Nd4j.createFromArray(1,0);
@ -2485,6 +2646,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvertPermutation(){ public void testInvertPermutation(){
DynamicCustomOp op = DynamicCustomOp.builder("invert_permutation") DynamicCustomOp op = DynamicCustomOp.builder("invert_permutation")
.addInputs(Nd4j.createFromArray(1, 0)) .addInputs(Nd4j.createFromArray(1, 0))
@ -2492,7 +2655,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBroadcastInt1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastInt1(Nd4jBackend backend) {
INDArray out = Nd4j.create(DataType.INT, 1); INDArray out = Nd4j.create(DataType.INT, 1);
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape") DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
@ -2505,6 +2670,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastInt2(){ public void testBroadcastInt2(){
INDArray out = Nd4j.create(DataType.INT, 2); INDArray out = Nd4j.create(DataType.INT, 2);
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape") DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
@ -2544,7 +2711,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testMergeMaxIndex() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeMaxIndex(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2561,7 +2730,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTriOp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTriOp(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable out = new Tri(sd, DataType.INT32, 3, 5, 2).outputVariable(); SDVariable out = new Tri(sd, DataType.INT32, 3, 5, 2).outputVariable();
@ -2573,7 +2744,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTriuOp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTriuOp(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable input = sd.var(Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {7,8,9},{10,11,12}})); SDVariable input = sd.var(Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {7,8,9},{10,11,12}}));

View File

@ -26,6 +26,8 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -94,9 +96,6 @@ public class TransformOpValidation extends BaseOpValidation {
private DataType initialType; private DataType initialType;
public TransformOpValidation(Nd4jBackend backend) {
super(backend);
}
@BeforeEach @BeforeEach
public void before() { public void before() {
@ -120,7 +119,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testScalarOps() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarOps(Nd4jBackend backend) {
int d0 = 2; int d0 = 2;
int d1 = 3; int d1 = 3;
int d2 = 4; int d2 = 4;
@ -217,7 +218,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testScalarMulCF() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarMulCF(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
INDArray outC = Nd4j.createUninitialized(3, 4); INDArray outC = Nd4j.createUninitialized(3, 4);
@ -231,7 +234,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test @Test
public void testScalarMulCF2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarMulCF2(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
@ -242,7 +247,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testCross() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCross(Nd4jBackend backend) {
INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3}); INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3});
INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3}); INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3});
@ -270,7 +277,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSpaceToDepth() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpaceToDepth(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);
int miniBatch = 128; int miniBatch = 128;
@ -298,7 +307,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDepthToSpace() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepthToSpace(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);
int miniBatch = 128; int miniBatch = 128;
@ -325,7 +336,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBatchToSpace() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBatchToSpace(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863 //OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);
@ -362,7 +375,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSpaceToBatch() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpaceToBatch(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863 //OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
Nd4j.getRandom().setSeed(7331); Nd4j.getRandom().setSeed(7331);
@ -400,7 +415,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDynamicPartition() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicPartition(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new double[]{4, 3, 5, 7, 8, 0}); INDArray ia = Nd4j.create(new double[]{4, 3, 5, 7, 8, 0});
@ -440,7 +457,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDynamicPartition2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicPartition2(Nd4jBackend backend) {
INDArray data = Nd4j.createFromArray(2, 1, 2, 0); INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition") INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
@ -458,7 +477,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDynamicStitch() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicStitch(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new double[]{5, 1, 3}, new long[]{3}); INDArray ia = Nd4j.create(new double[]{5, 1, 3}, new long[]{3});
@ -495,7 +516,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDiag() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiag(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new double[]{1, 2}, new int[]{2}); INDArray ia = Nd4j.create(new double[]{1, 2}, new int[]{2});
@ -521,7 +544,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDiagPart() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagPart(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray input = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4); INDArray input = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
@ -540,7 +565,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testEye() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEye(Nd4jBackend backend) {
int[] rows = new int[]{3, 3, 3, 3}; int[] rows = new int[]{3, 3, 3, 3};
int[] cols = new int[]{3, 2, 2, 2}; int[] cols = new int[]{3, 2, 2, 2};
int[][] batch = new int[][]{{}, {}, {4}, {3, 3}}; int[][] batch = new int[][]{{}, {}, {4}, {3, 3}};
@ -574,7 +601,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testEyeShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEyeShape(Nd4jBackend backend) {
DynamicCustomOp dco = DynamicCustomOp.builder("eye") DynamicCustomOp dco = DynamicCustomOp.builder("eye")
.addIntegerArguments(3, 3) .addIntegerArguments(3, 3)
//.addIntegerArguments(-99,3,3) //Also fails //.addIntegerArguments(-99,3,3) //Also fails
@ -586,7 +615,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTransforms() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTransforms(Nd4jBackend backend) {
//Test transforms (non-pairwise) //Test transforms (non-pairwise)
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1074,7 +1105,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test @Test
public void testPairwiseTransforms() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPairwiseTransforms(Nd4jBackend backend) {
/* /*
add, sub, mul, div, rsub, rdiv add, sub, mul, div, rsub, rdiv
eq, neq, gt, lt, gte, lte, or, and, xor eq, neq, gt, lt, gte, lte, or, and, xor
@ -1258,7 +1291,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testIsX() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIsX(Nd4jBackend backend) {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
@ -1313,7 +1348,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReplaceWhereScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReplaceWhereScalar(Nd4jBackend backend) {
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) { for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
log.info("Testing condition: " + c.getClass().getSimpleName()); log.info("Testing condition: " + c.getClass().getSimpleName());
@ -1335,7 +1372,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReplaceWhereArray() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReplaceWhereArray(Nd4jBackend backend) {
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) { for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
INDArray inArr = Nd4j.rand(3, 4); INDArray inArr = Nd4j.rand(3, 4);
@ -1358,7 +1397,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLogGrad() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogGrad(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
SDVariable input = sameDiff.var("x", Nd4j.linspace(1, 4, 4, DataType.DOUBLE)); SDVariable input = sameDiff.var("x", Nd4j.linspace(1, 4, 4, DataType.DOUBLE));
SDVariable log = sameDiff.math().log(input); SDVariable log = sameDiff.math().log(input);
@ -1369,7 +1410,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test @Test
public void testSigmoidBackwards() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSigmoidBackwards(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
Map<String, INDArray> inputs = new HashMap<>(); Map<String, INDArray> inputs = new HashMap<>();
@ -1387,7 +1430,9 @@ public class TransformOpValidation extends BaseOpValidation {
/* @Test /* @Test
public void testDepth() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepth(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
SDVariable x = sameDiff.one("one",new long[]{2,2}); SDVariable x = sameDiff.one("one",new long[]{2,2});
assertEquals(0,x.depth()); assertEquals(0,x.depth());
@ -1396,7 +1441,9 @@ public class TransformOpValidation extends BaseOpValidation {
}*/ }*/
@Test @Test
public void testRank0EdgeCase() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRank0EdgeCase(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))); SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4})));
double d0 = v1.eval().getDouble(0); double d0 = v1.eval().getDouble(0);
@ -1409,7 +1456,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testAtan2BroadcastShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAtan2BroadcastShape(Nd4jBackend backend) {
INDArray arr1 = Nd4j.create(new long[]{3, 1, 4}); INDArray arr1 = Nd4j.create(new long[]{3, 1, 4});
INDArray arr2 = Nd4j.create(new long[]{1, 2, 4}); INDArray arr2 = Nd4j.create(new long[]{1, 2, 4});
@ -1424,7 +1473,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBooleanAnd() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBooleanAnd(Nd4jBackend backend) {
Nd4j.setDataType(DataType.FLOAT); Nd4j.setDataType(DataType.FLOAT);
INDArray arr1 = Nd4j.create(new long[]{3, 4}); INDArray arr1 = Nd4j.create(new long[]{3, 4});
INDArray arr2 = Nd4j.create(new long[]{3, 4}); INDArray arr2 = Nd4j.create(new long[]{3, 4});
@ -1438,7 +1489,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testScatterOpsScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScatterOpsScalar(Nd4jBackend backend) {
for (String s : new String[]{"add", "sub", "mul", "div"}) { for (String s : new String[]{"add", "sub", "mul", "div"}) {
INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3); INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
INDArray indices = Nd4j.scalar(5); INDArray indices = Nd4j.scalar(5);
@ -1483,7 +1536,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Disabled("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540") @Disabled("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540")
@Test @Test
public void testPad() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPad(Nd4jBackend backend) {
INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0); INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG); INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG);
INDArray value = Nd4j.scalar(10.0); INDArray value = Nd4j.scalar(10.0);
@ -1510,7 +1565,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test @Test
public void testMirrorPad() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMirrorPad(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT); INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
@ -1543,7 +1600,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMirrorPad2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMirrorPad2(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT); INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
@ -1569,7 +1628,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMirrorPadSymmetric() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMirrorPadSymmetric(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4); INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT); INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT);
@ -1596,7 +1657,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testUnique() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnique(Nd4jBackend backend) {
INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4}); INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4});
INDArray expUnique = Nd4j.create(new double[]{3, 4, 1, 0, 2}); INDArray expUnique = Nd4j.create(new double[]{3, 4, 1, 0, 2});
@ -1618,7 +1681,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTopK() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopK(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //Can't assume sorted here OpValidationSuite.ignoreFailing(); //Can't assume sorted here
INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8}); INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8});
@ -1647,7 +1712,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTopK1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopK1(Nd4jBackend backend) {
INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0); INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
INDArray k = Nd4j.scalar(1); INDArray k = Nd4j.scalar(1);
INDArray outValue = Nd4j.create(DataType.DOUBLE, 1); INDArray outValue = Nd4j.create(DataType.DOUBLE, 1);
@ -1668,7 +1735,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testInTopK() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInTopK(Nd4jBackend backend) {
for (int k = 4; k >= 1; k--) { for (int k = 4; k >= 1; k--) {
log.info("Testing: k=" + k); log.info("Testing: k=" + k);
INDArray in = Nd4j.linspace(1, 20, 20, DataType.DOUBLE).reshape(4, 5); INDArray in = Nd4j.linspace(1, 20, 20, DataType.DOUBLE).reshape(4, 5);
@ -1709,7 +1778,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testZeta() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZeta(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182 OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182
INDArray x = Nd4j.rand(3, 4).addi(1.0); INDArray x = Nd4j.rand(3, 4).addi(1.0);
INDArray q = Nd4j.rand(3, 4); INDArray q = Nd4j.rand(3, 4);
@ -1726,7 +1797,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMaxEmptyScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxEmptyScalar(Nd4jBackend backend) {
INDArray empty = Nd4j.empty(DataType.FLOAT); INDArray empty = Nd4j.empty(DataType.FLOAT);
INDArray scalar = Nd4j.scalar(1.0f); INDArray scalar = Nd4j.scalar(1.0f);
@ -1743,7 +1816,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBroadcastEmpty() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastEmpty(Nd4jBackend backend) {
// Nd4j.getExecutioner().enableVerboseMode(true); // Nd4j.getExecutioner().enableVerboseMode(true);
// Nd4j.getExecutioner().enableDebugMode(true); // Nd4j.getExecutioner().enableDebugMode(true);
//Check broadcast behaviour with empty arrays. The idea is to match TF import behaviour, for import //Check broadcast behaviour with empty arrays. The idea is to match TF import behaviour, for import
@ -1833,7 +1908,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStandardize() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardize(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4}); final INDArray random = Nd4j.rand(new int[]{10, 4});
final int[] axis = new int[]{1}; final int[] axis = new int[]{1};
@ -1854,7 +1931,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStandardizeOP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardizeOP(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4}); final INDArray random = Nd4j.rand(new int[]{10, 4});
final int[] axis = new int[]{1}; final int[] axis = new int[]{1};
@ -1869,7 +1948,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStandardizeNoDeviation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardizeNoDeviation(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4}); final INDArray random = Nd4j.rand(new int[]{10, 4});
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
random.putScalar(1, i, 7); random.putScalar(1, i, 7);
@ -1895,7 +1976,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMatMulTensor() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatMulTensor(Nd4jBackend backend) {
final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5}); final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5});
final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6}); final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6});
@ -1915,7 +1998,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMatMulTensorTranspose() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatMulTensorTranspose(Nd4jBackend backend) {
for (boolean transposeA : new boolean[]{false, true}) { for (boolean transposeA : new boolean[]{false, true}) {
for (boolean transposeB : new boolean[]{false, true}) { for (boolean transposeB : new boolean[]{false, true}) {
for (boolean transposeResult : new boolean[]{false, true}) { for (boolean transposeResult : new boolean[]{false, true}) {
@ -2008,7 +2093,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSoftmaxCF() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSoftmaxCF(Nd4jBackend backend) {
INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5); INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5);
INDArray arrF = arrC.dup('f'); INDArray arrF = arrC.dup('f');
@ -2029,7 +2116,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLogSumExp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogSumExp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4); INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2044,7 +2133,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLogSumExp2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogSumExp2(Nd4jBackend backend) {
for (int dim = 0; dim <= 2; dim++) { for (int dim = 0; dim <= 2; dim++) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -2065,7 +2156,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test @Test
public void testCRELU() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCRELU(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2); INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2);
@ -2084,7 +2177,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testClipByAvgNorm() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByAvgNorm(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2, 2); INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2, 2);
@ -2105,7 +2200,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testEmbeddingLookup() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmbeddingLookup(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2118,7 +2215,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testImageResize() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testImageResize(Nd4jBackend backend) {
//TODO: Methods failed ResizeLanczos5, ResizeMitchelcubic, ResizeArea //TODO: Methods failed ResizeLanczos5, ResizeMitchelcubic, ResizeArea
@ -2160,7 +2259,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test @Test
public void testMaximumBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaximumBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2177,7 +2278,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMergeAddBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeAddBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2194,7 +2297,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMergeMaxBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeMaxBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2212,7 +2317,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test @Test
public void testMergeAvgBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeAvgBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2229,7 +2336,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReverseBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReverseBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2243,7 +2352,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testUpsampling3dBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUpsampling3dBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
for (boolean dataformat : new boolean[]{true, false}) { for (boolean dataformat : new boolean[]{true, false}) {

View File

@ -24,8 +24,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
@ -36,11 +37,8 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
public class ConvConfigTests extends BaseNd4jTest { public class ConvConfigTests extends BaseNd4jTestWithBackends {
public ConvConfigTests(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -48,7 +46,9 @@ public class ConvConfigTests extends BaseNd4jTest {
} }
@Test @Test
public void testDeConv2D(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeConv2D(Nd4jBackend backend){
DeConv2DConfig.builder().kH(2).kW(4).build(); DeConv2DConfig.builder().kH(2).kW(4).build();
try{ try{
@ -109,7 +109,9 @@ public class ConvConfigTests extends BaseNd4jTest {
} }
@Test @Test
public void testConv2D(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv2D(Nd4jBackend backend){
Conv2DConfig.builder().kH(2).kW(4).build(); Conv2DConfig.builder().kH(2).kW(4).build();
try{ try{
@ -170,7 +172,9 @@ public class ConvConfigTests extends BaseNd4jTest {
} }
@Test @Test
public void testPooling2D(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPooling2D(Nd4jBackend backend){
Pooling2DConfig.builder().kH(2).kW(4).build(); Pooling2DConfig.builder().kH(2).kW(4).build();
try{ try{
@ -231,7 +235,9 @@ public class ConvConfigTests extends BaseNd4jTest {
} }
@Test @Test
public void testDeConv3D(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeConv3D(Nd4jBackend backend){
DeConv3DConfig.builder().kH(2).kW(4).kD(3).build(); DeConv3DConfig.builder().kH(2).kW(4).kD(3).build();
try{ try{
@ -320,7 +326,9 @@ public class ConvConfigTests extends BaseNd4jTest {
} }
@Test @Test
public void testConv3D(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv3D(Nd4jBackend backend){
Conv3DConfig.builder().kH(2).kW(4).kD(3).build(); Conv3DConfig.builder().kH(2).kW(4).kD(3).build();
try{ try{
@ -411,7 +419,9 @@ public class ConvConfigTests extends BaseNd4jTest {
@Test @Test
public void testPooling3D(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPooling3D(Nd4jBackend backend){
Pooling3DConfig.builder().kH(2).kW(4).kD(3).build(); Pooling3DConfig.builder().kH(2).kW(4).kD(3).build();
try{ try{
@ -500,6 +510,8 @@ public class ConvConfigTests extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1D(){ public void testConv1D(){
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build(); Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();

View File

@ -23,8 +23,10 @@ package org.nd4j.autodiff.samediff;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -40,11 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Disabled("AB 2019/05/21 - JVM Crash on ppc64 - Issue #7657") @Disabled("AB 2019/05/21 - JVM Crash on ppc64 - Issue #7657")
public class FailingSameDiffTests extends BaseNd4jTest { public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
public FailingSameDiffTests(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -52,7 +51,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
} }
@Test @Test
public void testEye(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEye(Nd4jBackend backend){
//OpValidationSuite.ignoreFailing(); //OpValidationSuite.ignoreFailing();
INDArray arr = Nd4j.create(new double[]{1, 0, 0, 0, 1, 0}, new int[]{2, 3}); INDArray arr = Nd4j.create(new double[]{1, 0, 0, 0, 1, 0}, new int[]{2, 3});
List<INDArray> stack = new ArrayList<>(); List<INDArray> stack = new ArrayList<>();
@ -68,7 +69,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
} }
@Test @Test
public void testEyeShape(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEyeShape(Nd4jBackend backend){
val dco = DynamicCustomOp.builder("eye") val dco = DynamicCustomOp.builder("eye")
.addIntegerArguments(3,3) .addIntegerArguments(3,3)
//.addIntegerArguments(-99,3,3) //Also fails //.addIntegerArguments(-99,3,3) //Also fails
@ -80,7 +83,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
} }
@Test @Test
public void testExecutionDifferentShapesTransform(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExecutionDifferentShapesTransform(Nd4jBackend backend){
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4)); SDVariable in = sd.var("in", Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4));
@ -101,7 +106,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
} }
@Test @Test
public void testDropout() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDropout(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
double p = 0.5; double p = 0.5;
@ -114,7 +121,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
} }
@Test @Test
public void testExecutionDifferentShapesDynamicCustom(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExecutionDifferentShapesDynamicCustom(Nd4jBackend backend){
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();

View File

@ -26,13 +26,15 @@ import org.apache.commons.io.IOUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.graph.FlatConfiguration; import org.nd4j.graph.FlatConfiguration;
import org.nd4j.graph.FlatGraph; import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode; import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatVariable; import org.nd4j.graph.FlatVariable;
import org.nd4j.graph.IntPair; import org.nd4j.graph.IntPair;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
@ -70,11 +72,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j @Slf4j
public class FlatBufferSerdeTest extends BaseNd4jTest { public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
public FlatBufferSerdeTest(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -84,7 +83,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
@Test @Test
public void testBasic(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4); INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() ); SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() );
@ -139,7 +140,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
} }
@Test @Test
public void testSimple(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
for( int i = 0; i < 10; i++ ) { for( int i = 0; i < 10; i++ ) {
for(boolean execFirst : new boolean[]{false, true}) { for(boolean execFirst : new boolean[]{false, true}) {
log.info("Starting test: i={}, execFirst={}", i, execFirst); log.info("Starting test: i={}, execFirst={}", i, execFirst);
@ -268,7 +271,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
@Test @Test
public void testTrainingSerde(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
//Ensure 2 things: //Ensure 2 things:
//1. Training config is serialized/deserialized correctly //1. Training config is serialized/deserialized correctly
@ -352,7 +357,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
@Test @Test
public void pooling3DSerialization(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void pooling3DSerialization(Nd4jBackend backend){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28); SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28);
@ -372,7 +379,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
} }
@Test @Test
public void pooling3DSerialization2(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void pooling3DSerialization2(Nd4jBackend backend){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28); SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28);

View File

@ -22,12 +22,14 @@ package org.nd4j.autodiff.samediff;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.transform.GraphTransformUtil; import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
import org.nd4j.autodiff.samediff.transform.OpPredicate; import org.nd4j.autodiff.samediff.transform.OpPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraph; import org.nd4j.autodiff.samediff.transform.SubGraph;
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate; import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor; import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
@ -42,11 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j @Slf4j
public class GraphTransformUtilTests extends BaseNd4jTest { public class GraphTransformUtilTests extends BaseNd4jTestWithBackends {
public GraphTransformUtilTests(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -54,7 +53,9 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
} }
@Test @Test
public void testBasic(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBasic(Nd4jBackend backend){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 32); SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 32);
@ -93,7 +94,9 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
} }
@Test @Test
public void testSubgraphReplace1(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSubgraphReplace1(Nd4jBackend backend){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 4); SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 4);

View File

@ -21,8 +21,10 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr; import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -32,11 +34,8 @@ import java.lang.reflect.Field;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class MemoryMgrTest extends BaseNd4jTest { public class MemoryMgrTest extends BaseNd4jTestWithBackends {
public MemoryMgrTest(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -44,7 +43,9 @@ public class MemoryMgrTest extends BaseNd4jTest {
} }
@Test @Test
public void testArrayReuseTooLarge() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testArrayReuseTooLarge(Nd4jBackend backend) throws Exception {
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr(); ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
Field f = ArrayCacheMemoryMgr.class.getDeclaredField("maxCacheBytes"); Field f = ArrayCacheMemoryMgr.class.getDeclaredField("maxCacheBytes");
@ -97,7 +98,7 @@ public class MemoryMgrTest extends BaseNd4jTest {
assertEquals(10, mmgr.getLruCacheValues().size()); assertEquals(10, mmgr.getLruCacheValues().size());
//now, allocate some values: //now, allocate some values:
for( int i=1; i<=10; i++ ) { for( int i = 1; i <= 10; i++) {
INDArray a1 = mmgr.allocate(true, DataType.FLOAT, 25); INDArray a1 = mmgr.allocate(true, DataType.FLOAT, 25);
assertEquals(1000 - i * 100, mmgr.getCurrentCacheSize()); assertEquals(1000 - i * 100, mmgr.getCurrentCacheSize());
assertEquals(1000 - i * 100, as.getBytesSum()); assertEquals(1000 - i * 100, as.getBytesSum());
@ -116,10 +117,12 @@ public class MemoryMgrTest extends BaseNd4jTest {
} }
@Test @Test
public void testManyArrays(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testManyArrays(Nd4jBackend backend){
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr(); ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
for( int i=0; i<1000; i++ ){ for( int i = 0; i < 1000; i++) {
mmgr.release(Nd4j.scalar(0)); mmgr.release(Nd4j.scalar(0));
} }
@ -127,7 +130,7 @@ public class MemoryMgrTest extends BaseNd4jTest {
assertEquals(1000, mmgr.getLruCache().size()); assertEquals(1000, mmgr.getLruCache().size());
assertEquals(1000, mmgr.getLruCacheValues().size()); assertEquals(1000, mmgr.getLruCacheValues().size());
for( int i=0; i<1000; i++ ){ for( int i = 0; i < 1000; i++ ){
mmgr.release(Nd4j.scalar(0)); mmgr.release(Nd4j.scalar(0));
} }

View File

@ -21,9 +21,11 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -35,19 +37,18 @@ import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class NameScopeTests extends BaseNd4jTest { public class NameScopeTests extends BaseNd4jTestWithBackends {
public NameScopeTests(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering() {
return 'c'; return 'c';
} }
@Test @Test
public void testVariableNameScopesBasic(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVariableNameScopesBasic(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable v = sd.var("x"); SDVariable v = sd.var("x");
@ -73,7 +74,9 @@ public class NameScopeTests extends BaseNd4jTest {
} }
@Test @Test
public void testOpFieldsAndNames(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOpFieldsAndNames(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable x = sd.var("x", DataType.FLOAT, 1); SDVariable x = sd.var("x", DataType.FLOAT, 1);
@ -151,7 +154,9 @@ public class NameScopeTests extends BaseNd4jTest {
} }
@Test @Test
public void testNoNesting(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNoNesting(Nd4jBackend backend) {
SameDiff SD = SameDiff.create(); SameDiff SD = SameDiff.create();
SDVariable a = SD.constant(4); SDVariable a = SD.constant(4);
@ -168,7 +173,9 @@ public class NameScopeTests extends BaseNd4jTest {
} }
@Test @Test
public void testNoTesting2(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNoTesting2(Nd4jBackend backend) {
SameDiff SD = SameDiff.create(); SameDiff SD = SameDiff.create();
SDVariable a = SD.constant(4); SDVariable a = SD.constant(4);

View File

@ -21,21 +21,16 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.imports.tfgraphs.TFGraphTestZooModels;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.AtomicBoolean; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.resources.Resources;
import java.io.File;
import java.nio.file.Path;
import java.util.Collections; import java.util.Collections;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore; import java.util.concurrent.Semaphore;
@ -55,7 +50,9 @@ public class SameDiffMultiThreadTests extends BaseND4JTest {
} }
@Test @Test
public void testSimple() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(Nd4jBackend backend) throws Exception {
int nThreads = 4; int nThreads = 4;
int nRuns = 1000; int nRuns = 1000;
@ -103,48 +100,6 @@ public class SameDiffMultiThreadTests extends BaseND4JTest {
} }
} }
@Test
@Disabled //2020/03/24 AB - https://github.com/eclipse/deeplearning4j/issues/8802
public void testMobilenet(@TempDir Path testDir) throws Exception {
TFGraphTestZooModels.currentTestDir = testDir.toFile();
File f = Resources.asFile("tf_graphs/zoo_models/mobilenet_v2_1.0_224/tf_model.txt");
SameDiff sd = TFGraphTestZooModels.LOADER.apply(f, "mobilenet_v2_1.0_224");
// System.out.println(sd.summary());
int nThreads = 4;
int nRuns = 30;
INDArray[] inputArrs = new INDArray[nThreads];
INDArray[] expOut = new INDArray[nThreads];
for( int i=0; i<nThreads; i++ ){
if(i == 0 || i > 2)
inputArrs[i] = Nd4j.rand(DataType.FLOAT, 1, 224, 224, 3);
else if(i == 1)
inputArrs[i] = Nd4j.zeros(DataType.FLOAT, 1, 224, 224, 3);
else if(i == 2)
inputArrs[i] = Nd4j.ones(DataType.FLOAT, 1, 224, 224, 3);
expOut[i] = sd.outputSingle(Collections.singletonMap("input", inputArrs[i]), "MobilenetV2/Predictions/Reshape_1");
Nd4j.getExecutioner().commit();
}
AtomicBoolean[] failuresByThread = new AtomicBoolean[nThreads];
AtomicInteger[] counters = new AtomicInteger[nThreads];
Semaphore s = new Semaphore(nThreads);
CountDownLatch latch = new CountDownLatch(nThreads);
doTest(sd, nThreads, nRuns, inputArrs, expOut, "input", "MobilenetV2/Predictions/Reshape_1", failuresByThread, counters, s, latch);
s.release(nThreads);
latch.await();
for(int i = 0; i < nThreads; i++) {
assertFalse( failuresByThread[i].get(),"Thread " + i + " failed");
}
for(int i = 0; i < nThreads; i++) {
assertEquals( nRuns, counters[i].get(),"Thread " + i + " number of runs");
}
}
public static void doTest(SameDiff sd, int nThreads, int nRuns, INDArray[] inputArrs, INDArray[] expOut, public static void doTest(SameDiff sd, int nThreads, int nRuns, INDArray[] inputArrs, INDArray[] expOut,

View File

@ -21,7 +21,9 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -31,14 +33,13 @@ import org.nd4j.linalg.learning.config.Sgd;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class SameDiffOutputTest extends BaseNd4jTest { public class SameDiffOutputTest extends BaseNd4jTestWithBackends {
public SameDiffOutputTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void outputTest(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void outputTest(Nd4jBackend backend){
DataSet data = new DataSet(Nd4j.zeros(10, 10), Nd4j.zeros(10, 10)); DataSet data = new DataSet(Nd4j.zeros(10, 10), Nd4j.zeros(10, 10));
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();

View File

@ -21,7 +21,9 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -35,19 +37,18 @@ import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.assertNull; import static junit.framework.TestCase.assertNull;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest { public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
public SameDiffSpecifiedLossVarsTests(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering() {
return 'c'; return 'c';
} }
@Test @Test
public void testSpecifiedLoss1(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpecifiedLoss1(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable ph1 = sd.var("ph", DataType.FLOAT, 3, 4); SDVariable ph1 = sd.var("ph", DataType.FLOAT, 3, 4);
ph1.setArray(Nd4j.create(DataType.FLOAT, 3, 4)); ph1.setArray(Nd4j.create(DataType.FLOAT, 3, 4));
@ -68,7 +69,9 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
} }
@Test @Test
public void testSpecifiedLoss2(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpecifiedLoss2(Nd4jBackend backend) {
for( int i=0; i<2; i++ ) { for( int i=0; i<2; i++ ) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable ph = sd.placeHolder("ph", DataType.FLOAT, 3, 4); SDVariable ph = sd.placeHolder("ph", DataType.FLOAT, 3, 4);
@ -121,7 +124,9 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
@Test @Test
public void testTrainingDifferentLosses(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingDifferentLosses(Nd4jBackend backend) {
//Net with 2 losses: train on the first one, then change losses //Net with 2 losses: train on the first one, then change losses
//Also check that if modifying via add/setLossVariables the training config changes //Also check that if modifying via add/setLossVariables the training config changes

View File

@ -30,11 +30,13 @@ import java.util.Map;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.impl.ScoreListener; import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.listeners.records.History; import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -55,14 +57,13 @@ import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.weightinit.impl.XavierInitScheme; import org.nd4j.weightinit.impl.XavierInitScheme;
@Slf4j @Slf4j
public class SameDiffTrainingTest extends BaseNd4jTest { public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
public SameDiffTrainingTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void irisTrainingSanityCheck() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisTrainingSanityCheck(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize(); NormalizerStandardize std = new NormalizerStandardize();
@ -134,7 +135,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
@Test @Test
public void irisTrainingEvalTest() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisTrainingEvalTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize(); NormalizerStandardize std = new NormalizerStandardize();
@ -184,7 +187,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
@Test @Test
public void irisTrainingValidationTest() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisTrainingValidationTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize(); NormalizerStandardize std = new NormalizerStandardize();
@ -239,6 +244,8 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingMixedDtypes(){ public void testTrainingMixedDtypes(){
for (String u : new String[]{"adam", "nesterov", "adamax", "amsgrad"}) { for (String u : new String[]{"adam", "nesterov", "adamax", "amsgrad"}) {
@ -301,7 +308,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
} }
@Test @Test
public void simpleClassification() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void simpleClassification(Nd4jBackend backend) {
double learning_rate = 0.001; double learning_rate = 0.001;
int seed = 7; int seed = 7;
org.nd4j.linalg.api.rng.Random rng = Nd4j.getRandom(); org.nd4j.linalg.api.rng.Random rng = Nd4j.getRandom();
@ -348,6 +357,8 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingEvalVarNotReqForLoss(){ public void testTrainingEvalVarNotReqForLoss(){
//If a variable is not required for the loss - normally it won't be calculated //If a variable is not required for the loss - normally it won't be calculated
//But we want to make sure it IS calculated here - so we can perform evaluation on it //But we want to make sure it IS calculated here - so we can perform evaluation on it

View File

@ -25,11 +25,13 @@ import org.junit.Assert;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener; import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig; import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.IrisDataSetIterator; import org.nd4j.linalg.dataset.IrisDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -48,11 +50,8 @@ import java.util.concurrent.TimeUnit;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class CheckpointListenerTest extends BaseNd4jTest { public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
public CheckpointListenerTest(Nd4jBackend backend){
super(backend);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -96,7 +95,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
@Test @Test
public void testCheckpointEveryEpoch(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();
@ -130,7 +131,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
} }
@Test @Test
public void testCheckpointEvery5Iter(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();
@ -169,7 +172,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
@Test @Test
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();
@ -199,7 +204,7 @@ public class CheckpointListenerTest extends BaseNd4jTest {
for(File f : files){ for(File f : files){
String s = f.getAbsolutePath(); String s = f.getAbsolutePath();
// System.out.println(s); // System.out.println(s);
for( int i=0; i<names.size(); i++ ){ for( int i = 0; i < names.size(); i++ ){
if(s.contains(names.get(i))){ if(s.contains(names.get(i))){
found[i] = true; found[i] = true;
break; break;
@ -213,7 +218,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
} }
@Test @Test
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();

View File

@ -21,12 +21,13 @@
package org.nd4j.autodiff.samediff.listeners; package org.nd4j.autodiff.samediff.listeners;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener; import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig; import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -34,14 +35,13 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
public class ExecDebuggingListenerTest extends BaseNd4jTest { public class ExecDebuggingListenerTest extends BaseNd4jTestWithBackends {
public ExecDebuggingListenerTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testExecDebugListener(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExecDebugListener(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);

View File

@ -21,6 +21,8 @@
package org.nd4j.autodiff.samediff.listeners; package org.nd4j.autodiff.samediff.listeners;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener; import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Listener; import org.nd4j.autodiff.listeners.Listener;
@ -38,7 +40,7 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.Evaluation.Metric; import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.OpContext;
@ -61,11 +63,8 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class ListenerTest extends BaseNd4jTest { public class ListenerTest extends BaseNd4jTestWithBackends {
public ListenerTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -73,7 +72,9 @@ public class ListenerTest extends BaseNd4jTest {
} }
@Test @Test
public void irisHistoryTest() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisHistoryTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize(); NormalizerStandardize std = new NormalizerStandardize();
@ -136,6 +137,8 @@ public class ListenerTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testListenerCalls(){ public void testListenerCalls(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
@ -273,7 +276,9 @@ public class ListenerTest extends BaseNd4jTest {
} }
@Test @Test
public void testCustomListener() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCustomListener(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4); SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3); SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);

View File

@ -26,11 +26,13 @@ import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.profiler.ProfilingListener; import org.nd4j.autodiff.listeners.profiler.ProfilingListener;
import org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer; import org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -45,11 +47,8 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
public class ProfilingListenerTest extends BaseNd4jTest { public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
public ProfilingListenerTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -59,7 +58,9 @@ public class ProfilingListenerTest extends BaseNd4jTest {
@Test @Test
public void testProfilingListenerSimple(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
@ -108,18 +109,24 @@ public class ProfilingListenerTest extends BaseNd4jTest {
/* /*
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoadTfProfile(){ public void testLoadTfProfile(){
File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json"); File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json");
ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoadTfProfileDir(){ public void testLoadTfProfileDir(){
File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles"); File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles");
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoadTfProfileDir2(){ public void testLoadTfProfileDir2(){
File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0"); File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0");
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);

View File

@ -27,6 +27,8 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.VariableType;
@ -41,7 +43,7 @@ import org.nd4j.graph.UIInfoType;
import org.nd4j.graph.UIOp; import org.nd4j.graph.UIOp;
import org.nd4j.graph.UIVariable; import org.nd4j.graph.UIVariable;
import org.nd4j.graph.ui.LogFileWriter; import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -60,11 +62,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j @Slf4j
public class FileReadWriteTests extends BaseNd4jTest { public class FileReadWriteTests extends BaseNd4jTestWithBackends {
public FileReadWriteTests(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -81,7 +80,9 @@ public class FileReadWriteTests extends BaseNd4jTest {
} }
@Test @Test
public void testSimple(@TempDir Path testDir) throws IOException { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4); SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
SDVariable sum = v.sum(); SDVariable sum = v.sum();
@ -185,7 +186,9 @@ public class FileReadWriteTests extends BaseNd4jTest {
} }
@Test @Test
public void testNullBinLabels(@TempDir Path testDir) throws Exception{ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{
File dir = testDir.toFile(); File dir = testDir.toFile();
File f = new File(dir, "temp.bin"); File f = new File(dir, "temp.bin");
LogFileWriter w = new LogFileWriter(f); LogFileWriter w = new LogFileWriter(f);

View File

@ -25,6 +25,8 @@ import com.google.flatbuffers.Table;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.impl.UIListener; import org.nd4j.autodiff.listeners.impl.UIListener;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -34,7 +36,7 @@ import org.nd4j.graph.UIEvent;
import org.nd4j.graph.UIGraphStructure; import org.nd4j.graph.UIGraphStructure;
import org.nd4j.graph.UIStaticInfoRecord; import org.nd4j.graph.UIStaticInfoRecord;
import org.nd4j.graph.ui.LogFileWriter; import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.IrisDataSetIterator; import org.nd4j.linalg.dataset.IrisDataSetIterator;
@ -51,11 +53,8 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class UIListenerTest extends BaseNd4jTest { public class UIListenerTest extends BaseNd4jTestWithBackends {
public UIListenerTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -65,7 +64,9 @@ public class UIListenerTest extends BaseNd4jTest {
@Test @Test
public void testUIListenerBasic(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -101,7 +102,9 @@ public class UIListenerTest extends BaseNd4jTest {
} }
@Test @Test
public void testUIListenerContinue(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
SameDiff sd1 = getSimpleNet(); SameDiff sd1 = getSimpleNet();
@ -192,7 +195,9 @@ public class UIListenerTest extends BaseNd4jTest {
} }
@Test @Test
public void testUIListenerBadContinue(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
SameDiff sd1 = getSimpleNet(); SameDiff sd1 = getSimpleNet();

View File

@ -23,18 +23,17 @@ package org.nd4j.evaluation;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.custom.CustomEvaluation; import org.nd4j.evaluation.custom.CustomEvaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
public class CustomEvaluationTest extends BaseNd4jTest { public class CustomEvaluationTest extends BaseNd4jTestWithBackends {
public CustomEvaluationTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -42,8 +41,10 @@ public class CustomEvaluationTest extends BaseNd4jTest {
} }
@Test @Test
public void customEvalTest(){ @ParameterizedTest
CustomEvaluation accuracyEval = new CustomEvaluation<Pair<Number, Long>>( @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void customEvalTest(Nd4jBackend backend){
CustomEvaluation accuracyEval = new CustomEvaluation<>(
(labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)), (labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)),
CustomEvaluation.mergeConcatenate()); CustomEvaluation.mergeConcatenate());

View File

@ -21,6 +21,8 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.evaluation.classification.EvaluationCalibration;
@ -29,17 +31,14 @@ import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
public class EmptyEvaluationTests extends BaseNd4jTest { public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
public EmptyEvaluationTests(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -47,7 +46,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
} }
@Test @Test
public void testEmptyEvaluation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyEvaluation (Nd4jBackend backend) {
Evaluation e = new Evaluation(); Evaluation e = new Evaluation();
System.out.println(e.stats()); System.out.println(e.stats());
@ -62,7 +63,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
} }
@Test @Test
public void testEmptyRegressionEvaluation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyRegressionEvaluation (Nd4jBackend backend) {
RegressionEvaluation re = new RegressionEvaluation(); RegressionEvaluation re = new RegressionEvaluation();
re.stats(); re.stats();
@ -76,7 +79,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
} }
@Test @Test
public void testEmptyEvaluationBinary() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyEvaluationBinary(Nd4jBackend backend) {
EvaluationBinary eb = new EvaluationBinary(); EvaluationBinary eb = new EvaluationBinary();
eb.stats(); eb.stats();
@ -91,7 +96,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
} }
@Test @Test
public void testEmptyROC() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyROC(Nd4jBackend backend) {
ROC roc = new ROC(); ROC roc = new ROC();
roc.stats(); roc.stats();
@ -106,7 +113,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
} }
@Test @Test
public void testEmptyROCBinary() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyROCBinary(Nd4jBackend backend) {
ROCBinary rb = new ROCBinary(); ROCBinary rb = new ROCBinary();
rb.stats(); rb.stats();
@ -121,7 +130,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
} }
@Test @Test
public void testEmptyROCMultiClass() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyROCMultiClass(Nd4jBackend backend) {
ROCMultiClass r = new ROCMultiClass(); ROCMultiClass r = new ROCMultiClass();
r.stats(); r.stats();
@ -136,7 +147,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
} }
@Test @Test
public void testEmptyEvaluationCalibration() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyEvaluationCalibration(Nd4jBackend backend) {
EvaluationCalibration ec = new EvaluationCalibration(); EvaluationCalibration ec = new EvaluationCalibration();
ec.stats(); ec.stats();

View File

@ -21,9 +21,11 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin; import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
@ -36,11 +38,8 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class EvalCustomThreshold extends BaseNd4jTest { public class EvalCustomThreshold extends BaseNd4jTestWithBackends {
public EvalCustomThreshold(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -48,7 +47,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationCustomBinaryThreshold() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCustomBinaryThreshold(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
//Sanity checks: 0.5 threshold for 1-output and 2-output binary cases //Sanity checks: 0.5 threshold for 1-output and 2-output binary cases
@ -114,7 +115,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationCostArray() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCostArray(Nd4jBackend backend) {
int nExamples = 20; int nExamples = 20;
@ -162,7 +165,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinaryCustomThreshold() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryCustomThreshold(Nd4jBackend backend) {
//Sanity check: same results for 0.5 threshold vs. default (no threshold) //Sanity check: same results for 0.5 threshold vs. default (no threshold)
int nExamples = 20; int nExamples = 20;

View File

@ -21,6 +21,8 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.evaluation.classification.EvaluationCalibration;
@ -31,7 +33,7 @@ import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve; import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -42,11 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class EvalJsonTest extends BaseNd4jTest { public class EvalJsonTest extends BaseNd4jTestWithBackends {
public EvalJsonTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -54,7 +53,9 @@ public class EvalJsonTest extends BaseNd4jTest {
} }
@Test @Test
public void testSerdeEmpty() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerdeEmpty(Nd4jBackend backend) {
boolean print = false; boolean print = false;
IEvaluation[] arr = new IEvaluation[] {new Evaluation(), new EvaluationBinary(), new ROCBinary(10), IEvaluation[] arr = new IEvaluation[] {new Evaluation(), new EvaluationBinary(), new ROCBinary(10),
@ -74,7 +75,9 @@ public class EvalJsonTest extends BaseNd4jTest {
} }
@Test @Test
public void testSerde() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerde(Nd4jBackend backend) {
boolean print = false; boolean print = false;
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -122,7 +125,9 @@ public class EvalJsonTest extends BaseNd4jTest {
} }
@Test @Test
public void testSerdeExactRoc() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerdeExactRoc(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
boolean print = false; boolean print = false;
@ -200,7 +205,9 @@ public class EvalJsonTest extends BaseNd4jTest {
} }
@Test @Test
public void testJsonYamlCurves() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testJsonYamlCurves(Nd4jBackend backend) {
ROC roc = new ROC(0); ROC roc = new ROC(0);
INDArray evalLabel = INDArray evalLabel =
@ -252,7 +259,9 @@ public class EvalJsonTest extends BaseNd4jTest {
} }
@Test @Test
public void testJsonWithCustomThreshold() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testJsonWithCustomThreshold(Nd4jBackend backend) {
//Evaluation - binary threshold //Evaluation - binary threshold
Evaluation e = new Evaluation(0.25); Evaluation e = new Evaluation(0.25);

View File

@ -21,8 +21,10 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -39,11 +41,8 @@ import static org.junit.jupiter.api.Assertions.*;
import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
public class EvalTest extends BaseNd4jTest { public class EvalTest extends BaseNd4jTestWithBackends {
public EvalTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -52,7 +51,9 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
public void testEval() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEval(Nd4jBackend backend) {
int classNum = 5; int classNum = 5;
Evaluation eval = new Evaluation (classNum); Evaluation eval = new Evaluation (classNum);
@ -91,7 +92,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testEval2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEval2(Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
Evaluation first = null; Evaluation first = null;
@ -150,7 +153,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testStringListLabels() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStringListLabels(Nd4jBackend backend) {
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
@ -167,7 +172,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testStringHashLabels() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStringHashLabels(Nd4jBackend backend) {
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
@ -184,7 +191,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvalMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalMasking(Nd4jBackend backend) {
int miniBatch = 5; int miniBatch = 5;
int nOut = 3; int nOut = 3;
int tsLength = 6; int tsLength = 6;
@ -251,7 +260,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testFalsePerfectRecall() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFalsePerfectRecall(Nd4jBackend backend) {
int testSize = 100; int testSize = 100;
int numClasses = 5; int numClasses = 5;
int winner = 1; int winner = 1;
@ -284,7 +295,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationMerging(Nd4jBackend backend) {
int nRows = 20; int nRows = 20;
int nCols = 3; int nCols = 3;
@ -358,7 +371,9 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
public void testSingleClassBinaryClassification() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSingleClassBinaryClassification(Nd4jBackend backend) {
Evaluation eval = new Evaluation(1); Evaluation eval = new Evaluation(1);
@ -387,7 +402,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvalInvalid() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalInvalid(Nd4jBackend backend) {
Evaluation e = new Evaluation(5); Evaluation e = new Evaluation(5);
e.eval(0, 1); e.eval(0, 1);
e.eval(1, 0); e.eval(1, 0);
@ -400,7 +417,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvalMethods() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalMethods(Nd4jBackend backend) {
//Check eval(int,int) vs. eval(INDArray,INDArray) //Check eval(int,int) vs. eval(INDArray,INDArray)
Evaluation e1 = new Evaluation(4); Evaluation e1 = new Evaluation(4);
@ -443,7 +462,9 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
public void testTopNAccuracy() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopNAccuracy(Nd4jBackend backend) {
Evaluation e = new Evaluation(null, 3); Evaluation e = new Evaluation(null, 3);
@ -504,7 +525,9 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
public void testTopNAccuracyMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopNAccuracyMerging(Nd4jBackend backend) {
Evaluation e1 = new Evaluation(null, 3); Evaluation e1 = new Evaluation(null, 3);
Evaluation e2 = new Evaluation(null, 3); Evaluation e2 = new Evaluation(null, 3);
@ -552,7 +575,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testBinaryCase() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBinaryCase(Nd4jBackend backend) {
INDArray ones10 = Nd4j.ones(10, 1); INDArray ones10 = Nd4j.ones(10, 1);
INDArray ones4 = Nd4j.ones(4, 1); INDArray ones4 = Nd4j.ones(4, 1);
INDArray zeros4 = Nd4j.zeros(4, 1); INDArray zeros4 = Nd4j.zeros(4, 1);
@ -581,7 +606,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testF1FBeta_MicroMacroAveraging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testF1FBeta_MicroMacroAveraging(Nd4jBackend backend) {
//Confusion matrix: rows = actual, columns = predicted //Confusion matrix: rows = actual, columns = predicted
//[3, 1, 0] //[3, 1, 0]
//[2, 2, 1] //[2, 2, 1]
@ -722,7 +749,9 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
public void testConfusionMatrixStats() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConfusionMatrixStats(Nd4jBackend backend) {
Evaluation e = new Evaluation(); Evaluation e = new Evaluation();
@ -743,6 +772,8 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalBinaryMetrics(){ public void testEvalBinaryMetrics(){
Evaluation ePosClass1_nOut2 = new Evaluation(2, 1); Evaluation ePosClass1_nOut2 = new Evaluation(2, 1);
@ -864,6 +895,8 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConfusionMatrixString(){ public void testConfusionMatrixString(){
Evaluation e = new Evaluation(Arrays.asList("a","b","c")); Evaluation e = new Evaluation(Arrays.asList("a","b","c"));
@ -914,6 +947,8 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationNaNs(){ public void testEvaluationNaNs(){
Evaluation e = new Evaluation(); Evaluation e = new Evaluation();
@ -929,6 +964,8 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentation(){ public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1023,6 +1060,8 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLabelReset(){ public void testLabelReset(){
Map<Integer,String> m = new HashMap<>(); Map<Integer,String> m = new HashMap<>();
@ -1056,6 +1095,8 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalStatsBinaryCase(){ public void testEvalStatsBinaryCase(){
//Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case //Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case

View File

@ -21,9 +21,11 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -38,11 +40,8 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.nd4j.evaluation.classification.EvaluationBinary.Metric.*; import static org.nd4j.evaluation.classification.EvaluationBinary.Metric.*;
public class EvaluationBinaryTest extends BaseNd4jTest { public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
public EvaluationBinaryTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -50,7 +49,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinary() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary(Nd4jBackend backend) {
//Compare EvaluationBinary to Evaluation class //Compare EvaluationBinary to Evaluation class
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
EvaluationBinary first = null; EvaluationBinary first = null;
@ -136,7 +137,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinaryMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryMerging(Nd4jBackend backend) {
int nOut = 4; int nOut = 4;
int[] shape1 = {30, nOut}; int[] shape1 = {30, nOut};
int[] shape2 = {50, nOut}; int[] shape2 = {50, nOut};
@ -163,7 +166,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinaryPerOutputMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryPerOutputMasking(Nd4jBackend backend) {
//Provide a mask array: "ignore" the masked steps //Provide a mask array: "ignore" the masked steps
@ -206,7 +211,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testTimeSeriesEval() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTimeSeriesEval(Nd4jBackend backend) {
int[] shape = {2, 4, 3}; int[] shape = {2, 4, 3};
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -230,7 +237,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinaryWithROC() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryWithROC(Nd4jBackend backend) {
//Simple test for nested ROCBinary in EvaluationBinary //Simple test for nested ROCBinary in EvaluationBinary
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -247,7 +256,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
@Test @Test
public void testEvaluationBinary3d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -281,7 +292,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinary4d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -315,7 +328,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinary3dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -376,7 +391,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinary4dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -21,8 +21,10 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -39,19 +41,18 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class EvaluationCalibrationTest extends BaseNd4jTest { public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
public EvaluationCalibrationTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering () {
return 'c'; return 'c';
} }
@Test @Test
public void testReliabilityDiagram() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReliabilityDiagram (Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
EvaluationCalibration first = null; EvaluationCalibration first = null;
@ -143,7 +144,9 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
} }
@Test @Test
public void testLabelAndPredictionCounts() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLabelAndPredictionCounts (Nd4jBackend backend) {
int minibatch = 50; int minibatch = 50;
int nClasses = 3; int nClasses = 3;
@ -171,7 +174,9 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
} }
@Test @Test
public void testResidualPlots() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testResidualPlots (Nd4jBackend backend) {
int minibatch = 50; int minibatch = 50;
int nClasses = 3; int nClasses = 3;
@ -272,6 +277,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentation(){ public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -366,7 +373,9 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationCalibration3d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCalibration3d (Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -398,7 +407,9 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationCalibration3dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCalibration3dMasking (Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);

View File

@ -23,6 +23,8 @@ package org.nd4j.evaluation;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.evaluation.classification.EvaluationCalibration;
@ -30,17 +32,14 @@ import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary; import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
public class NewInstanceTest extends BaseNd4jTest { public class NewInstanceTest extends BaseNd4jTestWithBackends {
public NewInstanceTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -48,7 +47,9 @@ public class NewInstanceTest extends BaseNd4jTest {
} }
@Test @Test
public void testNewInstances() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewInstances(Nd4jBackend backend) {
boolean print = true; boolean print = true;
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -21,10 +21,12 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary; import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -39,11 +41,7 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class ROCBinaryTest extends BaseNd4jTest { public class ROCBinaryTest extends BaseNd4jTestWithBackends {
public ROCBinaryTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -51,7 +49,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testROCBinary() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary(Nd4jBackend backend) {
//Compare ROCBinary to ROC class //Compare ROCBinary to ROC class
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
@ -146,7 +146,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testRocBinaryMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBinaryMerging(Nd4jBackend backend) {
for (int nSteps : new int[]{30, 0}) { //0 == exact for (int nSteps : new int[]{30, 0}) { //0 == exact
int nOut = 4; int nOut = 4;
int[] shape1 = {30, nOut}; int[] shape1 = {30, nOut};
@ -176,7 +178,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
@Test @Test
public void testROCBinaryPerOutputMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinaryPerOutputMasking(Nd4jBackend backend) {
for (int nSteps : new int[]{30, 0}) { //0 == exact for (int nSteps : new int[]{30, 0}) { //0 == exact
@ -216,7 +220,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
@Test @Test
public void testROCBinary3d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -250,7 +256,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testROCBinary4d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -284,7 +292,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testROCBinary3dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -345,7 +355,9 @@ public class ROCBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testROCBinary4dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -21,12 +21,14 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary; import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve; import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -39,11 +41,8 @@ import java.util.*;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class ROCTest extends BaseNd4jTest { public class ROCTest extends BaseNd4jTestWithBackends {
public ROCTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -84,7 +83,9 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void testRocBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBasic(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax) //2 outputs here - probability distribution over classes (softmax)
INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601}, {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
@ -127,7 +128,9 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void testRocBasicSingleClass() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBasicSingleClass(Nd4jBackend backend) {
//1 output here - single probability value (sigmoid) //1 output here - single probability value (sigmoid)
//add 0.001 to avoid numerical/rounding issues (float vs. double, etc) //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
@ -165,7 +168,9 @@ public class ROCTest extends BaseNd4jTest {
@Test @Test
public void testRoc() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRoc(Nd4jBackend backend) {
//Previous tests allowed for a perfect classifier with right threshold... //Previous tests allowed for a perfect classifier with right threshold...
INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}}); INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}});
@ -250,7 +255,9 @@ public class ROCTest extends BaseNd4jTest {
@Test @Test
public void testRocTimeSeriesNoMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
//Same as first test... //Same as first test...
//2 outputs here - probability distribution over classes (softmax) //2 outputs here - probability distribution over classes (softmax)
@ -297,7 +304,9 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void testRocTimeSeriesMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocTimeSeriesMasking(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax) //2 outputs here - probability distribution over classes (softmax)
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601}, {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
@ -347,7 +356,9 @@ public class ROCTest extends BaseNd4jTest {
@Test @Test
public void testCompareRocAndRocMultiClass() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
//For 2 class case: ROC and Multi-class ROC should be the same... //For 2 class case: ROC and Multi-class ROC should be the same...
@ -377,7 +388,9 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void testCompare2Vs3Classes() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCompare2Vs3Classes(Nd4jBackend backend) {
//ROC multi-class: 2 vs. 3 classes should be the same, if we add two of the classes together... //ROC multi-class: 2 vs. 3 classes should be the same, if we add two of the classes together...
//Both methods implement one vs. all ROC/AUC in different ways //Both methods implement one vs. all ROC/AUC in different ways
@ -426,7 +439,9 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void testROCMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCMerging(Nd4jBackend backend) {
int nArrays = 10; int nArrays = 10;
int minibatch = 64; int minibatch = 64;
int nROCs = 3; int nROCs = 3;
@ -471,7 +486,9 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void testROCMerging2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCMerging2(Nd4jBackend backend) {
int nArrays = 10; int nArrays = 10;
int minibatch = 64; int minibatch = 64;
int exactAllocBlockSize = 10; int exactAllocBlockSize = 10;
@ -516,7 +533,9 @@ public class ROCTest extends BaseNd4jTest {
@Test @Test
public void testROCMultiMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCMultiMerging(Nd4jBackend backend) {
int nArrays = 10; int nArrays = 10;
int minibatch = 64; int minibatch = 64;
@ -564,7 +583,9 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void testAUCPrecisionRecall() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAUCPrecisionRecall(Nd4jBackend backend) {
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob //Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
//at threshold 0 to 0.24999: tp=2, fp=1, fn=0, tn=0 prec=2/(2+1)=0.666, recall=2/2=1.0 //at threshold 0 to 0.24999: tp=2, fp=1, fn=0, tn=0 prec=2/(2+1)=0.666, recall=2/2=1.0
//at threshold 0.25 to 0.33: tp=2, fp=0, fn=0, tn=1 prec=2/2=1, recall=2/2=1 //at threshold 0.25 to 0.33: tp=2, fp=0, fn=0, tn=1 prec=2/2=1, recall=2/2=1
@ -611,7 +632,9 @@ public class ROCTest extends BaseNd4jTest {
@Test @Test
public void testRocAucExact() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocAucExact(Nd4jBackend backend) {
//Check the implementation vs. Scikitlearn //Check the implementation vs. Scikitlearn
/* /*
@ -774,7 +797,9 @@ public class ROCTest extends BaseNd4jTest {
@Test @Test
public void rocExactEdgeCaseReallocation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
//Set reallocation block size to say 20, but then evaluate a 100-length array //Set reallocation block size to say 20, but then evaluate a 100-length array
@ -786,7 +811,9 @@ public class ROCTest extends BaseNd4jTest {
@Test @Test
public void testPrecisionRecallCurveGetPointMethods() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
double[] threshold = new double[101]; double[] threshold = new double[101];
double[] precision = threshold; double[] precision = threshold;
double[] recall = new double[101]; double[] recall = new double[101];
@ -822,7 +849,9 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void testPrecisionRecallCurveConfusion() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
//Sanity check: values calculated from the confusion matrix should match the PR curve values //Sanity check: values calculated from the confusion matrix should match the PR curve values
for (boolean removeRedundantPts : new boolean[] {true, false}) { for (boolean removeRedundantPts : new boolean[] {true, false}) {
@ -861,6 +890,8 @@ public class ROCTest extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocMerge(){ public void testRocMerge(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -905,6 +936,8 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocMultiMerge(){ public void testRocMultiMerge(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -954,6 +987,8 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBinaryMerge(){ public void testRocBinaryMerge(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -999,6 +1034,8 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentationBinary(){ public void testSegmentationBinary(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1089,6 +1126,8 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentation(){ public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -21,9 +21,11 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -40,11 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
public class RegressionEvalTest extends BaseNd4jTest { public class RegressionEvalTest extends BaseNd4jTestWithBackends {
public RegressionEvalTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -52,7 +51,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test() @Test()
public void testEvalParameters() { public void testEvalParameters(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
int specCols = 5; int specCols = 5;
INDArray labels = Nd4j.ones(3); INDArray labels = Nd4j.ones(3);
@ -65,7 +64,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testPerfectPredictions() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPerfectPredictions(Nd4jBackend backend) {
int nCols = 5; int nCols = 5;
int nTestArrays = 100; int nTestArrays = 100;
@ -92,7 +93,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testKnownValues() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testKnownValues(Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
RegressionEvaluation first = null; RegressionEvaluation first = null;
@ -148,7 +151,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
@Test @Test
public void testRegressionEvaluationMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEvaluationMerging(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nRows = 20; int nRows = 20;
@ -189,7 +194,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testRegressionEvalPerOutputMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEvalPerOutputMasking(Nd4jBackend backend) {
INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}}); INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}});
@ -216,6 +223,8 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEvalTimeSeriesSplit(){ public void testRegressionEvalTimeSeriesSplit(){
INDArray out1 = Nd4j.rand(new int[]{3, 5, 20}); INDArray out1 = Nd4j.rand(new int[]{3, 5, 20});
@ -238,7 +247,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testRegressionEval3d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -270,7 +281,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testRegressionEval4d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -302,7 +315,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testRegressionEval3dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -361,7 +376,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testRegressionEval4dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -22,10 +22,12 @@ package org.nd4j.evaluation;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -34,11 +36,8 @@ import java.nio.charset.StandardCharsets;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestLegacyJsonLoading extends BaseNd4jTest { public class TestLegacyJsonLoading extends BaseNd4jTestWithBackends {
public TestLegacyJsonLoading(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -46,7 +45,9 @@ public class TestLegacyJsonLoading extends BaseNd4jTest {
} }
@Test @Test
public void testEvalLegacyFormat() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalLegacyFormat(Nd4jBackend backend) throws Exception {
File f = new ClassPathResource("regression_testing/eval_100b/evaluation.json").getFile(); File f = new ClassPathResource("regression_testing/eval_100b/evaluation.json").getFile();
String s = FileUtils.readFileToString(f, StandardCharsets.UTF_8); String s = FileUtils.readFileToString(f, StandardCharsets.UTF_8);

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -38,17 +39,14 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class AveragingTests extends BaseNd4jTest { public class AveragingTests extends BaseNd4jTestWithBackends {
private final int THREADS = 16; private final int THREADS = 16;
private final int LENGTH = 51200 * 4; private final int LENGTH = 51200 * 4;
DataType initialType; DataType initialType = Nd4j.dataType();
public AveragingTests(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
public void setUp() { public void setUp() {
@ -63,7 +61,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test @Test
public void testSingleDeviceAveraging1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSingleDeviceAveraging1(Nd4jBackend backend) {
INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0); INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0);
INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0); INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0);
INDArray array3 = Nd4j.valueArrayOf(LENGTH, 3.0); INDArray array3 = Nd4j.valueArrayOf(LENGTH, 3.0);
@ -110,7 +110,9 @@ public class AveragingTests extends BaseNd4jTest {
} }
@Test @Test
public void testSingleDeviceAveraging2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSingleDeviceAveraging2(Nd4jBackend backend) {
INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH); INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH);
List<INDArray> arrays = new ArrayList<>(); List<INDArray> arrays = new ArrayList<>();
for (int i = 0; i < THREADS; i++) for (int i = 0; i < THREADS; i++)
@ -127,7 +129,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test @Test
public void testAccumulation1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAccumulation1(Nd4jBackend backend) {
INDArray array1 = Nd4j.create(100).assign(1.0); INDArray array1 = Nd4j.create(100).assign(1.0);
INDArray array2 = Nd4j.create(100).assign(2.0); INDArray array2 = Nd4j.create(100).assign(2.0);
INDArray array3 = Nd4j.create(100).assign(3.0); INDArray array3 = Nd4j.create(100).assign(3.0);
@ -140,7 +144,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test @Test
public void testAccumulation2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAccumulation2(Nd4jBackend backend) {
INDArray array1 = Nd4j.create(100).assign(1.0); INDArray array1 = Nd4j.create(100).assign(1.0);
INDArray array2 = Nd4j.create(100).assign(2.0); INDArray array2 = Nd4j.create(100).assign(2.0);
INDArray array3 = Nd4j.create(100).assign(3.0); INDArray array3 = Nd4j.create(100).assign(3.0);
@ -155,7 +161,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test @Test
public void testAccumulation3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAccumulation3(Nd4jBackend backend) {
// we want to ensure that cuda backend is able to launch this op on cpu // we want to ensure that cuda backend is able to launch this op on cpu
Nd4j.getAffinityManager().allowCrossDeviceAccess(false); Nd4j.getAffinityManager().allowCrossDeviceAccess(false);

View File

@ -23,8 +23,9 @@ package org.nd4j.linalg;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -34,15 +35,14 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
@Slf4j @Slf4j
public class DataTypeTest extends BaseNd4jTest { public class DataTypeTest extends BaseNd4jTestWithBackends {
public DataTypeTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testDataTypes() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDataTypes(Nd4jBackend backend) throws Exception {
for (val type : DataType.values()) { for (val type : DataType.values()) {
if (DataType.UTF8.equals(type) || DataType.UNKNOWN.equals(type) || DataType.COMPRESSED.equals(type)) if (DataType.UTF8.equals(type) || DataType.UNKNOWN.equals(type) || DataType.COMPRESSED.equals(type))
continue; continue;

View File

@ -21,20 +21,17 @@
package org.nd4j.linalg; package org.nd4j.linalg;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
@RunWith(Parameterized.class)
public class InputValidationTests extends BaseNd4jTest {
public InputValidationTests(Nd4jBackend backend) { public class InputValidationTests extends BaseNd4jTestWithBackends {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -45,7 +42,9 @@ public class InputValidationTests extends BaseNd4jTest {
///////////////////// Broadcast Tests /////////////////////// ///////////////////// Broadcast Tests ///////////////////////
@Test @Test
public void testInvalidColVectorOp1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidColVectorOp1(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10); INDArray first = Nd4j.create(10, 10);
INDArray col = Nd4j.create(5, 1); INDArray col = Nd4j.create(5, 1);
try { try {
@ -57,7 +56,9 @@ public class InputValidationTests extends BaseNd4jTest {
} }
@Test @Test
public void testInvalidColVectorOp2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidColVectorOp2(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10); INDArray first = Nd4j.create(10, 10);
INDArray col = Nd4j.create(5, 1); INDArray col = Nd4j.create(5, 1);
try { try {
@ -69,7 +70,9 @@ public class InputValidationTests extends BaseNd4jTest {
} }
@Test @Test
public void testInvalidRowVectorOp1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidRowVectorOp1(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10); INDArray first = Nd4j.create(10, 10);
INDArray row = Nd4j.create(1, 5); INDArray row = Nd4j.create(1, 5);
try { try {
@ -81,7 +84,9 @@ public class InputValidationTests extends BaseNd4jTest {
} }
@Test @Test
public void testInvalidRowVectorOp2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidRowVectorOp2(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10); INDArray first = Nd4j.create(10, 10);
INDArray row = Nd4j.create(1, 5); INDArray row = Nd4j.create(1, 5);
try { try {

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.commons.lang3.RandomUtils; import org.apache.commons.lang3.RandomUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
@ -47,14 +48,13 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class LoneTest extends BaseNd4jTest { public class LoneTest extends BaseNd4jTestWithBackends {
public LoneTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testSoftmaxStability() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSoftmaxStability(Nd4jBackend backend) {
INDArray input = Nd4j.create(new double[]{-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose(); INDArray input = Nd4j.create(new double[]{-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose();
// System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo())); // System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo()));
INDArray output = Nd4j.create(DataType.DOUBLE, 10, 1); INDArray output = Nd4j.create(DataType.DOUBLE, 10, 1);
@ -68,7 +68,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void testFlattenedView() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFlattenedView(Nd4jBackend backend) {
int rows = 8; int rows = 8;
int cols = 8; int cols = 8;
int dim2 = 4; int dim2 = 4;
@ -104,7 +106,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void testIndexingColVec() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexingColVec(Nd4jBackend backend) {
int elements = 5; int elements = 5;
INDArray rowVector = Nd4j.linspace(1, elements, elements).reshape(1, elements); INDArray rowVector = Nd4j.linspace(1, elements, elements).reshape(1, elements);
INDArray colVector = rowVector.transpose(); INDArray colVector = rowVector.transpose();
@ -123,7 +127,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void concatScalarVectorIssue() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void concatScalarVectorIssue(Nd4jBackend backend) {
//A bug was found when the first array that concat sees is a scalar and the rest vectors + scalars //A bug was found when the first array that concat sees is a scalar and the rest vectors + scalars
INDArray arr1 = Nd4j.create(1, 1); INDArray arr1 = Nd4j.create(1, 1);
INDArray arr2 = Nd4j.create(1, 8); INDArray arr2 = Nd4j.create(1, 8);
@ -133,7 +139,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void reshapeTensorMmul() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void reshapeTensorMmul(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 2, 12).reshape(2, 3, 2); INDArray a = Nd4j.linspace(1, 2, 12).reshape(2, 3, 2);
INDArray b = Nd4j.linspace(3, 4, 4).reshape(2, 2); INDArray b = Nd4j.linspace(3, 4, 4).reshape(2, 2);
int[][] axes = new int[2][]; int[][] axes = new int[2][];
@ -145,7 +153,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void maskWhenMerge() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void maskWhenMerge(Nd4jBackend backend) {
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5)); DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3)); DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
List<DataSet> dataSetList = new ArrayList<DataSet>(); List<DataSet> dataSetList = new ArrayList<DataSet>();
@ -160,7 +170,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void testRelu() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRelu(Nd4jBackend backend) {
INDArray aA = Nd4j.linspace(-3, 4, 8).reshape(2, 4); INDArray aA = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4); INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
INDArray b = Nd4j.getExecutioner().exec(new Tanh(aA)); INDArray b = Nd4j.getExecutioner().exec(new Tanh(aA));
@ -172,7 +184,7 @@ public class LoneTest extends BaseNd4jTest {
@Test @Test
//broken at a threshold //broken at a threshold
public void testArgMax() { public void testArgMax(Nd4jBackend backend) {
int max = 63; int max = 63;
INDArray A = Nd4j.linspace(1, max, max).reshape(1, max); INDArray A = Nd4j.linspace(1, max, max).reshape(1, max);
int currentArgMax = Nd4j.argMax(A).getInt(0); int currentArgMax = Nd4j.argMax(A).getInt(0);
@ -186,7 +198,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void testRPF() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRPF(Nd4jBackend backend) {
val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3); val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3);
log.info("--------"); log.info("--------");
@ -199,7 +213,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void testConcat3D_Vstack_C() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat3D_Vstack_C(Nd4jBackend backend) {
val shape = new long[]{1, 1000, 20}; val shape = new long[]{1, 1000, 20};
List<INDArray> cArrays = new ArrayList<>(); List<INDArray> cArrays = new ArrayList<>();
@ -229,7 +245,9 @@ public class LoneTest extends BaseNd4jTest {
@Test @Test
public void testGetRow1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRow1(Nd4jBackend backend) {
INDArray array = Nd4j.create(10000, 10000); INDArray array = Nd4j.create(10000, 10000);
//Thread.sleep(10000); //Thread.sleep(10000);
@ -256,7 +274,7 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test() @Test()
public void checkIllegalElementOps() { public void checkIllegalElementOps(Nd4jBackend backend) {
assertThrows(Exception.class,() -> { assertThrows(Exception.class,() -> {
INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5); INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5);
INDArray B = A.dup().reshape(2, 2, 5); INDArray B = A.dup().reshape(2, 2, 5);
@ -268,7 +286,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void checkSliceofSlice() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void checkSliceofSlice(Nd4jBackend backend) {
/* /*
Issue 1: Slice of slice with c order and f order views are not equal Issue 1: Slice of slice with c order and f order views are not equal
@ -308,7 +328,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void checkWithReshape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void checkWithReshape(Nd4jBackend backend) {
INDArray arr = Nd4j.create(1, 3); INDArray arr = Nd4j.create(1, 3);
INDArray reshaped = arr.reshape('f', 3, 1); INDArray reshaped = arr.reshape('f', 3, 1);
for (int i=0;i<reshaped.length();i++) { for (int i=0;i<reshaped.length();i++) {

View File

@ -21,6 +21,8 @@
package org.nd4j.linalg; package org.nd4j.linalg;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -28,11 +30,8 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class MmulBug extends BaseNd4jTest { public class MmulBug extends BaseNd4jTestWithBackends {
public MmulBug(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -40,7 +39,9 @@ public class MmulBug extends BaseNd4jTest {
} }
@Test @Test
public void simpleTest() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void simpleTest(Nd4jBackend backend) {
INDArray m1 = Nd4j.create(new double[][] {{1.0}, {2.0}, {3.0}, {4.0}}); INDArray m1 = Nd4j.create(new double[][] {{1.0}, {2.0}, {3.0}, {4.0}});
m1 = m1.reshape(2, 2); m1 = m1.reshape(2, 2);

View File

@ -25,8 +25,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -58,19 +59,14 @@ import static org.junit.jupiter.api.Assertions.*;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
@Slf4j @Slf4j
public class NDArrayTestsFortran extends BaseNd4jTest { public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
public NDArrayTestsFortran(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testScalarOps() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarOps(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3}); INDArray n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
assertEquals(27d, n.length(), 1e-1); assertEquals(27d, n.length(), 1e-1);
n.addi(Nd4j.scalar(1d)); n.addi(Nd4j.scalar(1d));
@ -88,7 +84,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testColumnMmul() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testColumnMmul(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 10, 18, DataType.FLOAT).data(); DataBuffer data = Nd4j.linspace(1, 10, 18, DataType.FLOAT).data();
INDArray x2 = Nd4j.create(data, new long[] {2, 3, 3}); INDArray x2 = Nd4j.create(data, new long[] {2, 3, 3});
data = Nd4j.linspace(1, 12, 9, DataType.FLOAT).data(); data = Nd4j.linspace(1, 12, 9, DataType.FLOAT).data();
@ -119,7 +117,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testRowVectorGemm() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRowVectorGemm(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1).castTo(DataType.DOUBLE); INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1).castTo(DataType.DOUBLE);
INDArray other = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4).castTo(DataType.DOUBLE); INDArray other = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4).castTo(DataType.DOUBLE);
INDArray result = linspace.mmul(other); INDArray result = linspace.mmul(other);
@ -130,13 +130,17 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testRepmat() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRepmat(Nd4jBackend backend) {
INDArray rowVector = Nd4j.create(1, 4); INDArray rowVector = Nd4j.create(1, 4);
INDArray repmat = rowVector.repmat(4, 4); INDArray repmat = rowVector.repmat(4, 4);
assertTrue(Arrays.equals(new long[] {4, 16}, repmat.shape())); assertTrue(Arrays.equals(new long[] {4, 16}, repmat.shape()));
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReadWrite() throws Exception { public void testReadWrite() throws Exception {
INDArray write = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray write = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
ByteArrayOutputStream bos = new ByteArrayOutputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream();
@ -152,6 +156,8 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReadWriteDouble() throws Exception { public void testReadWriteDouble() throws Exception {
INDArray write = Nd4j.linspace(1, 4, 4, DataType.FLOAT); INDArray write = Nd4j.linspace(1, 4, 4, DataType.FLOAT);
ByteArrayOutputStream bos = new ByteArrayOutputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream();
@ -168,18 +174,17 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMultiThreading() throws Exception { public void testMultiThreading() throws Exception {
ExecutorService ex = ExecutorServiceProvider.getExecutorService(); ExecutorService ex = ExecutorServiceProvider.getExecutorService();
List<Future<?>> list = new ArrayList<>(100); List<Future<?>> list = new ArrayList<>(100);
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
Future<?> future = ex.submit(new Runnable() { Future<?> future = ex.submit(() -> {
@Override
public void run() {
INDArray dot = Nd4j.linspace(1, 8, 8, DataType.DOUBLE); INDArray dot = Nd4j.linspace(1, 8, 8, DataType.DOUBLE);
// System.out.println(Transforms.sigmoid(dot)); // System.out.println(Transforms.sigmoid(dot));
Transforms.sigmoid(dot); Transforms.sigmoid(dot);
}
}); });
list.add(future); list.add(future);
} }
@ -191,7 +196,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testBroadcastingGenerated() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastingGenerated(Nd4jBackend backend) {
int[][] broadcastShape = NDArrayCreationUtil.getRandomBroadCastShape(7, 6, 10); int[][] broadcastShape = NDArrayCreationUtil.getRandomBroadCastShape(7, 6, 10);
List<List<Pair<INDArray, String>>> broadCastList = new ArrayList<>(broadcastShape.length); List<List<Pair<INDArray, String>>> broadCastList = new ArrayList<>(broadcastShape.length);
for (int[] shape : broadcastShape) { for (int[] shape : broadcastShape) {
@ -216,7 +223,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testBroadCasting() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadCasting(Nd4jBackend backend) {
INDArray first = Nd4j.arange(0, 3).reshape(3, 1).castTo(DataType.DOUBLE); INDArray first = Nd4j.arange(0, 3).reshape(3, 1).castTo(DataType.DOUBLE);
INDArray ret = first.broadcast(3, 4); INDArray ret = first.broadcast(3, 4);
INDArray testRet = Nd4j.create(new double[][] {{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}}); INDArray testRet = Nd4j.create(new double[][] {{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}});
@ -229,14 +238,18 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testOneTensor() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneTensor(Nd4jBackend backend) {
INDArray arr = Nd4j.ones(1, 1, 1, 1, 1, 1, 1); INDArray arr = Nd4j.ones(1, 1, 1, 1, 1, 1, 1);
INDArray matrixToBroadcast = Nd4j.ones(1, 1); INDArray matrixToBroadcast = Nd4j.ones(1, 1);
assertEquals(matrixToBroadcast.broadcast(arr.shape()), arr); assertEquals(matrixToBroadcast.broadcast(arr.shape()), arr);
} }
@Test @Test
public void testSortWithIndicesDescending() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSortWithIndicesDescending(Nd4jBackend backend) {
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
//indices,data //indices,data
INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, false); INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, false);
@ -247,7 +260,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testSortDeadlock() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSortDeadlock(Nd4jBackend backend) {
val toSort = Nd4j.linspace(DataType.DOUBLE, 1, 32*768, 1).reshape(32, 768); val toSort = Nd4j.linspace(DataType.DOUBLE, 1, 32*768, 1).reshape(32, 768);
val sorted = Nd4j.sort(toSort.dup(), 1, false); val sorted = Nd4j.sort(toSort.dup(), 1, false);
@ -255,7 +270,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testSortWithIndices() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSortWithIndices(Nd4jBackend backend) {
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
//indices,data //indices,data
INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, true); INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, true);
@ -266,14 +283,18 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testNd4jSortScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNd4jSortScalar(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(1, -1); INDArray linspace = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(1, -1);
INDArray sorted = Nd4j.sort(linspace, 1, false); INDArray sorted = Nd4j.sort(linspace, 1, false);
// System.out.println(sorted); // System.out.println(sorted);
} }
@Test @Test
public void testSwapAxesFortranOrder() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSwapAxesFortranOrder(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}).castTo(DataType.DOUBLE); INDArray n = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}).castTo(DataType.DOUBLE);
for (int i = 0; i < n.slices(); i++) { for (int i = 0; i < n.slices(); i++) {
INDArray nSlice = n.slice(i); INDArray nSlice = n.slice(i);
@ -292,7 +313,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testDimShuffle() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDimShuffle(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray twoOneTwo = n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false}); INDArray twoOneTwo = n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false});
assertTrue(Arrays.equals(new long[] {2, 1, 2}, twoOneTwo.shape())); assertTrue(Arrays.equals(new long[] {2, 1, 2}, twoOneTwo.shape()));
@ -303,7 +326,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testGetVsGetScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetVsGetScalar(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
float element = a.getFloat(0, 1); float element = a.getFloat(0, 1);
double element2 = a.getDouble(0, 1); double element2 = a.getDouble(0, 1);
@ -316,7 +341,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testDivide() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDivide(Nd4jBackend backend) {
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2}); INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
INDArray div = two.div(two); INDArray div = two.div(two);
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage()); assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage());
@ -330,7 +357,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testSigmoid() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSigmoid(Nd4jBackend backend) {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f}); INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
INDArray sigmoid = Transforms.sigmoid(n, false); INDArray sigmoid = Transforms.sigmoid(n, false);
@ -339,7 +368,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testNeg() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNeg(Nd4jBackend backend) {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4}); INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
INDArray neg = Transforms.neg(n); INDArray neg = Transforms.neg(n);
@ -349,7 +380,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testCosineSim() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCosineSim(Nd4jBackend backend) {
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
double sim = Transforms.cosineSim(vec1, vec2); double sim = Transforms.cosineSim(vec1, vec2);
@ -364,7 +397,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testExp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExp(Nd4jBackend backend) {
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f}); INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f});
INDArray exped = Transforms.exp(n); INDArray exped = Transforms.exp(n);
@ -374,7 +409,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalar(Nd4jBackend backend) {
INDArray a = Nd4j.scalar(1.0f); INDArray a = Nd4j.scalar(1.0f);
assertEquals(true, a.isScalar()); assertEquals(true, a.isScalar());
@ -386,7 +423,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testWrap() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testWrap(Nd4jBackend backend) {
int[] shape = {2, 4}; int[] shape = {2, 4};
INDArray d = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(shape[0], shape[1]); INDArray d = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(shape[0], shape[1]);
INDArray n = d; INDArray n = d;
@ -411,7 +450,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testGetRowFortran() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRowFortran(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.FLOAT).data(), new long[] {2, 2}); INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.FLOAT).data(), new long[] {2, 2});
INDArray column = Nd4j.create(new float[] {1, 3}); INDArray column = Nd4j.create(new float[] {1, 3});
INDArray column2 = Nd4j.create(new float[] {2, 4}); INDArray column2 = Nd4j.create(new float[] {2, 4});
@ -424,7 +465,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testGetColumnFortran() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumnFortran(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2}); INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2});
INDArray column = Nd4j.create(new double[] {1, 2}); INDArray column = Nd4j.create(new double[] {1, 2});
INDArray column2 = Nd4j.create(new double[] {3, 4}); INDArray column2 = Nd4j.create(new double[] {3, 4});
@ -438,7 +481,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testGetColumns() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumns(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3).castTo(DataType.DOUBLE); INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3).castTo(DataType.DOUBLE);
// log.info("Original: {}", matrix); // log.info("Original: {}", matrix);
INDArray matrixGet = matrix.getColumns(1, 2); INDArray matrixGet = matrix.getColumns(1, 2);
@ -452,7 +497,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testVectorInit() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorInit(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(); DataBuffer data = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data();
INDArray arr = Nd4j.create(data, new long[] {1, 4}); INDArray arr = Nd4j.create(data, new long[] {1, 4});
assertEquals(true, arr.isRowVector()); assertEquals(true, arr.isRowVector());
@ -465,7 +512,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testAssignOffset() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssignOffset(Nd4jBackend backend) {
INDArray arr = Nd4j.ones(5, 5); INDArray arr = Nd4j.ones(5, 5);
INDArray row = arr.slice(1); INDArray row = arr.slice(1);
row.assign(1); row.assign(1);
@ -473,7 +522,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testColumns() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testColumns(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE); INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE);
INDArray column = Nd4j.create(new double[] {1, 2, 3}); INDArray column = Nd4j.create(new double[] {1, 2, 3});
arr.putColumn(0, column); arr.putColumn(0, column);
@ -511,7 +562,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testPutRow() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRow(Nd4jBackend backend) {
INDArray d = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray d = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray n = d.dup(); INDArray n = d.dup();
@ -570,7 +623,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testInplaceTranspose() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInplaceTranspose(Nd4jBackend backend) {
INDArray test = Nd4j.rand(3, 4); INDArray test = Nd4j.rand(3, 4);
INDArray orig = test.dup(); INDArray orig = test.dup();
INDArray transposei = test.transposei(); INDArray transposei = test.transposei();
@ -585,7 +640,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testMmulF() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulF(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data(); DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data();
INDArray n = Nd4j.create(data, new long[] {1, 10}); INDArray n = Nd4j.create(data, new long[] {1, 10});
@ -603,7 +660,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testRowsColumns() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRowsColumns(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).data(); DataBuffer data = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).data();
INDArray rows = Nd4j.create(data, new long[] {2, 3}); INDArray rows = Nd4j.create(data, new long[] {2, 3});
assertEquals(2, rows.rows()); assertEquals(2, rows.rows());
@ -619,7 +678,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testTranspose() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTranspose(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.ones(100).castTo(DataType.DOUBLE).data(), new long[] {5, 5, 4}); INDArray n = Nd4j.create(Nd4j.ones(100).castTo(DataType.DOUBLE).data(), new long[] {5, 5, 4});
INDArray transpose = n.transpose(); INDArray transpose = n.transpose();
assertEquals(n.length(), transpose.length()); assertEquals(n.length(), transpose.length());
@ -647,7 +708,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testAddMatrix() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddMatrix(Nd4jBackend backend) {
INDArray five = Nd4j.ones(5); INDArray five = Nd4j.ones(5);
five.addi(five.dup()); five.addi(five.dup());
INDArray twos = Nd4j.valueArrayOf(5, 2); INDArray twos = Nd4j.valueArrayOf(5, 2);
@ -658,7 +721,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testMMul() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMMul(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}}); INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}}); INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
@ -669,7 +734,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testPutSlice() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutSlice(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3); INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3);
INDArray newSlice = Nd4j.create(DataType.DOUBLE, 3, 3); INDArray newSlice = Nd4j.create(DataType.DOUBLE, 3, 3);
Nd4j.exec(new PrintVariable(newSlice)); Nd4j.exec(new PrintVariable(newSlice));
@ -680,7 +747,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testRowVectorMultipleIndices() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4); INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
linear.putScalar(new long[] {0, 1}, 1); linear.putScalar(new long[] {0, 1}, 1);
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage()); assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage());
@ -689,7 +758,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testDim1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDim1(Nd4jBackend backend) {
INDArray sum = Nd4j.linspace(1, 2, 2, DataType.DOUBLE).reshape(2, 1); INDArray sum = Nd4j.linspace(1, 2, 2, DataType.DOUBLE).reshape(2, 1);
INDArray same = sum.dup(); INDArray same = sum.dup();
assertEquals(same.sum(1), sum.reshape(2)); assertEquals(same.sum(1), sum.reshape(2));
@ -697,7 +768,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testEps() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEps(Nd4jBackend backend) {
val ones = Nd4j.ones(5); val ones = Nd4j.ones(5);
val res = Nd4j.createUninitialized(DataType.BOOL, 5); val res = Nd4j.createUninitialized(DataType.BOOL, 5);
assertTrue(Nd4j.getExecutioner().exec(new Eps(ones, ones, res)).all()); assertTrue(Nd4j.getExecutioner().exec(new Eps(ones, ones, res)).all());
@ -705,7 +778,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testLogDouble() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogDouble(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).castTo(DataType.DOUBLE); INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).castTo(DataType.DOUBLE);
INDArray log = Transforms.log(linspace); INDArray log = Transforms.log(linspace);
INDArray assertion = Nd4j.create(new double[] {0, 0.6931471805599453, 1.0986122886681098, 1.3862943611198906, 1.6094379124341005, 1.791759469228055}); INDArray assertion = Nd4j.create(new double[] {0, 0.6931471805599453, 1.0986122886681098, 1.3862943611198906, 1.6094379124341005, 1.791759469228055});
@ -713,28 +788,36 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testVectorSum() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorSum(Nd4jBackend backend) {
INDArray lin = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray lin = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1); assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
} }
@Test @Test
public void testVectorSum2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorSum2(Nd4jBackend backend) {
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1); assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
} }
@Test @Test
public void testVectorSum3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorSum3(Nd4jBackend backend) {
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray lin2 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray lin2 = Nd4j.create(new double[] {1, 2, 3, 4});
assertEquals(lin, lin2); assertEquals(lin, lin2);
} }
@Test @Test
public void testSmallSum() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSmallSum(Nd4jBackend backend) {
INDArray base = Nd4j.create(new double[] {5.843333333333335, 3.0540000000000007}); INDArray base = Nd4j.create(new double[] {5.843333333333335, 3.0540000000000007});
base.addi(1e-12); base.addi(1e-12);
INDArray assertion = Nd4j.create(new double[] {5.84333433, 3.054001}); INDArray assertion = Nd4j.create(new double[] {5.84333433, 3.054001});
@ -745,7 +828,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testPermute() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4}); INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4});
INDArray transpose = n.transpose(); INDArray transpose = n.transpose();
INDArray permute = n.permute(1, 0); INDArray permute = n.permute(1, 0);
@ -774,7 +859,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testAppendBias() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAppendBias(Nd4jBackend backend) {
INDArray rand = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(1, -1).transpose(); INDArray rand = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(1, -1).transpose();
INDArray test = Nd4j.appendBias(rand); INDArray test = Nd4j.appendBias(rand);
INDArray assertion = Nd4j.toFlattened(rand, Nd4j.scalar(DataType.DOUBLE, 1.0)).reshape(-1, 1); INDArray assertion = Nd4j.toFlattened(rand, Nd4j.scalar(DataType.DOUBLE, 1.0)).reshape(-1, 1);
@ -782,7 +869,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testRand() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRand(Nd4jBackend backend) {
INDArray rand = Nd4j.randn(5, 5); INDArray rand = Nd4j.randn(5, 5);
Nd4j.getDistributions().createUniform(0.4, 4).sample(5); Nd4j.getDistributions().createUniform(0.4, 4).sample(5);
Nd4j.getDistributions().createNormal(1, 5).sample(10); Nd4j.getDistributions().createNormal(1, 5).sample(10);
@ -794,7 +883,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testIdentity() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIdentity(Nd4jBackend backend) {
INDArray eye = Nd4j.eye(5); INDArray eye = Nd4j.eye(5);
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape())); assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
eye = Nd4j.eye(5); eye = Nd4j.eye(5);
@ -805,7 +896,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testColumnVectorOpsFortran() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testColumnVectorOpsFortran(Nd4jBackend backend) {
INDArray twoByTwo = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray twoByTwo = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray toAdd = Nd4j.create(new float[] {1, 2}, new long[] {2, 1}); INDArray toAdd = Nd4j.create(new float[] {1, 2}, new long[] {2, 1});
twoByTwo.addiColumnVector(toAdd); twoByTwo.addiColumnVector(toAdd);
@ -816,7 +909,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testRSubi() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRSubi(Nd4jBackend backend) {
INDArray n2 = Nd4j.ones(2); INDArray n2 = Nd4j.ones(2);
INDArray n2Assertion = Nd4j.zeros(2); INDArray n2Assertion = Nd4j.zeros(2);
INDArray nRsubi = n2.rsubi(1); INDArray nRsubi = n2.rsubi(1);
@ -826,7 +921,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testAssign() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssign(Nd4jBackend backend) {
INDArray vector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray vector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
vector.assign(1); vector.assign(1);
assertEquals(Nd4j.ones(5).castTo(DataType.DOUBLE), vector); assertEquals(Nd4j.ones(5).castTo(DataType.DOUBLE), vector);
@ -843,7 +940,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testAddScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddScalar(Nd4jBackend backend) {
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0); INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
INDArray rdiv = div.add(1); INDArray rdiv = div.add(1);
INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 5.0); INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 5.0);
@ -851,7 +950,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testRdivScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRdivScalar(Nd4jBackend backend) {
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0); INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
INDArray rdiv = div.rdiv(1); INDArray rdiv = div.rdiv(1);
INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 0.25); INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 0.25);
@ -859,7 +960,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testRDivi() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRDivi(Nd4jBackend backend) {
INDArray n2 = Nd4j.valueArrayOf(new long[] {1, 2}, 4.0); INDArray n2 = Nd4j.valueArrayOf(new long[] {1, 2}, 4.0);
INDArray n2Assertion = Nd4j.valueArrayOf(new long[] {1, 2}, 0.5); INDArray n2Assertion = Nd4j.valueArrayOf(new long[] {1, 2}, 0.5);
INDArray nRsubi = n2.rdivi(2); INDArray nRsubi = n2.rdivi(2);
@ -869,7 +972,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testNumVectorsAlongDimension() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNumVectorsAlongDimension(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2);
assertEquals(12, arr.vectorsAlongDimension(2)); assertEquals(12, arr.vectorsAlongDimension(2));
} }
@ -877,7 +982,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testBroadCast() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadCast(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
INDArray broadCasted = n.broadcast(5, 4); INDArray broadCasted = n.broadcast(5, 4);
for (int i = 0; i < broadCasted.rows(); i++) { for (int i = 0; i < broadCasted.rows(); i++) {
@ -899,7 +1006,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testMatrix() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrix(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray brr = Nd4j.create(new double[] {5, 6}, new long[] {2}); INDArray brr = Nd4j.create(new double[] {5, 6}, new long[] {2});
INDArray row = arr.getRow(0); INDArray row = arr.getRow(0);
@ -909,7 +1018,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testPutRowGetRowOrdering() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRowGetRowOrdering(Nd4jBackend backend) {
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray put = Nd4j.create(new double[] {5, 6}); INDArray put = Nd4j.create(new double[] {5, 6});
row1.putRow(1, put); row1.putRow(1, put);
@ -931,7 +1042,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testSumWithRow1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSumWithRow1(Nd4jBackend backend) {
//Works: //Works:
INDArray array2d = Nd4j.ones(1, 10); INDArray array2d = Nd4j.ones(1, 10);
array2d.sum(0); //OK array2d.sum(0); //OK
@ -962,7 +1075,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testSumWithRow2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSumWithRow2(Nd4jBackend backend) {
//All sums in this method execute without exceptions. //All sums in this method execute without exceptions.
INDArray array3d = Nd4j.ones(2, 10, 10); INDArray array3d = Nd4j.ones(2, 10, 10);
array3d.sum(0); array3d.sum(0);
@ -985,7 +1100,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testPutRowFortran() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRowFortran(Nd4jBackend backend) {
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2).castTo(DataType.DOUBLE); INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2).castTo(DataType.DOUBLE);
INDArray put = Nd4j.create(new double[] {5, 6}); INDArray put = Nd4j.create(new double[] {5, 6});
row1.putRow(1, put); row1.putRow(1, put);
@ -998,7 +1115,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testElementWiseOps() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testElementWiseOps(Nd4jBackend backend) {
INDArray n1 = Nd4j.scalar(1); INDArray n1 = Nd4j.scalar(1);
INDArray n2 = Nd4j.scalar(2); INDArray n2 = Nd4j.scalar(2);
INDArray nClone = n1.add(n2); INDArray nClone = n1.add(n2);
@ -1021,7 +1140,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testRollAxis() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRollAxis(Nd4jBackend backend) {
INDArray toRoll = Nd4j.ones(3, 4, 5, 6); INDArray toRoll = Nd4j.ones(3, 4, 5, 6);
assertArrayEquals(new long[] {3, 6, 4, 5}, Nd4j.rollAxis(toRoll, 3, 1).shape()); assertArrayEquals(new long[] {3, 6, 4, 5}, Nd4j.rollAxis(toRoll, 3, 1).shape());
val shape = Nd4j.rollAxis(toRoll, 3).shape(); val shape = Nd4j.rollAxis(toRoll, 3).shape();
@ -1030,7 +1151,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
@Disabled @Disabled
public void testTensorDot() { public void testTensorDot(Nd4jBackend backend) {
INDArray oneThroughSixty = Nd4j.arange(60).reshape('f', 3, 4, 5).castTo(DataType.DOUBLE); INDArray oneThroughSixty = Nd4j.arange(60).reshape('f', 3, 4, 5).castTo(DataType.DOUBLE);
INDArray oneThroughTwentyFour = Nd4j.arange(24).reshape('f', 4, 3, 2).castTo(DataType.DOUBLE); INDArray oneThroughTwentyFour = Nd4j.arange(24).reshape('f', 4, 3, 2).castTo(DataType.DOUBLE);
INDArray result = Nd4j.tensorMmul(oneThroughSixty, oneThroughTwentyFour, new int[][] {{1, 0}, {0, 1}}); INDArray result = Nd4j.tensorMmul(oneThroughSixty, oneThroughTwentyFour, new int[][] {{1, 0}, {0, 1}});
@ -1043,7 +1164,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testNegativeShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNegativeShape(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
INDArray reshaped = linspace.reshape(-1, 2); INDArray reshaped = linspace.reshape(-1, 2);
assertArrayEquals(new long[] {2, 2}, reshaped.shape()); assertArrayEquals(new long[] {2, 2}, reshaped.shape());
@ -1055,7 +1178,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testGetColumnGetRow() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumnGetRow(Nd4jBackend backend) {
INDArray row = Nd4j.ones(1, 5); INDArray row = Nd4j.ones(1, 5);
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
INDArray col = row.getColumn(i); INDArray col = row.getColumn(i);
@ -1070,7 +1195,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testDupAndDupWithOrder() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDupAndDupWithOrder(Nd4jBackend backend) {
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE); List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
int count = 0; int count = 0;
for (Pair<INDArray, String> pair : testInputs) { for (Pair<INDArray, String> pair : testInputs) {
@ -1092,7 +1219,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testToOffsetZeroCopy() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToOffsetZeroCopy(Nd4jBackend backend) {
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE); List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
int cnt = 0; int cnt = 0;

View File

@ -23,8 +23,9 @@ package org.nd4j.linalg;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -42,18 +43,14 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class Nd4jTestsComparisonC extends BaseNd4jTest { public class Nd4jTestsComparisonC extends BaseNd4jTestWithBackends {
private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonC.class); private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonC.class);
public static final int SEED = 123; public static final int SEED = 123;
DataType initialType; DataType initialType = Nd4j.dataType();
public Nd4jTestsComparisonC(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
@ -73,7 +70,9 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest {
@Test @Test
public void testGemmWithOpsCommonsMath() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);

View File

@ -25,8 +25,9 @@ import org.apache.commons.math3.linear.RealMatrix;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -43,18 +44,14 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@RunWith(Parameterized.class)
public class Nd4jTestsComparisonFortran extends BaseNd4jTest { public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonFortran.class); private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonFortran.class);
public static final int SEED = 123; public static final int SEED = 123;
DataType initialType; DataType initialType = Nd4j.dataType();
public Nd4jTestsComparisonFortran(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
@ -75,7 +72,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
} }
@Test @Test
public void testCrash() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCrash(Nd4jBackend backend) {
INDArray array3d = Nd4j.ones(1, 10, 10); INDArray array3d = Nd4j.ones(1, 10, 10);
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 0); Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 0);
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 1); Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 1);
@ -85,7 +84,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
} }
@Test @Test
public void testMmulWithOpsCommonsMath() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
@ -100,7 +101,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
} }
@Test @Test
public void testGemmWithOpsCommonsMath() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
@ -156,7 +159,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
} }
@Test @Test
public void testGemvApacheCommons() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemvApacheCommons(Nd4jBackend backend) {
int[] rowsArr = new int[] {4, 4, 4, 8, 8, 8}; int[] rowsArr = new int[] {4, 4, 4, 8, 8, 8};
int[] colsArr = new int[] {2, 1, 10, 2, 1, 10}; int[] colsArr = new int[] {2, 1, 10, 2, 1, 10};
@ -211,7 +216,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
} }
@Test @Test
public void testAddSubtractWithOpsCommonsMath() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddSubtractWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
for (int i = 0; i < first.size(); i++) { for (int i = 0; i < first.size(); i++) {
@ -229,7 +236,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
} }
@Test @Test
public void testMulDivOnCheckUtilMatrices() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMulDivOnCheckUtilMatrices(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
for (int i = 0; i < first.size(); i++) { for (int i = 0; i < first.size(); i++) {

View File

@ -23,8 +23,9 @@ package org.nd4j.linalg;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -36,18 +37,15 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class Nd4jTestsF extends BaseNd4jTest {
DataType initialType; public class Nd4jTestsF extends BaseNd4jTestWithBackends {
public Nd4jTestsF(Nd4jBackend backend) { DataType initialType = Nd4j.dataType();
super(backend);
this.initialType = Nd4j.dataType();
}
@Test @Test
public void testConcat3D_Vstack_F() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat3D_Vstack_F(Nd4jBackend backend) {
//Nd4j.getExecutioner().enableVerboseMode(true); //Nd4j.getExecutioner().enableVerboseMode(true);
//Nd4j.getExecutioner().enableDebugMode(true); //Nd4j.getExecutioner().enableDebugMode(true);
@ -79,7 +77,9 @@ public class Nd4jTestsF extends BaseNd4jTest {
@Test @Test
public void testSlice_1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlice_1(Nd4jBackend backend) {
val arr = Nd4j.linspace(1,4, 4, DataType.DOUBLE).reshape(2, 2, 1); val arr = Nd4j.linspace(1,4, 4, DataType.DOUBLE).reshape(2, 2, 1);
val exp0 = Nd4j.create(new double[]{1, 3}, new int[] {2, 1}); val exp0 = Nd4j.create(new double[]{1, 3}, new int[] {2, 1});
val exp1 = Nd4j.create(new double[]{2, 4}, new int[] {2, 1}); val exp1 = Nd4j.create(new double[]{2, 4}, new int[] {2, 1});

View File

@ -22,8 +22,9 @@ package org.nd4j.linalg;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -34,15 +35,13 @@ import java.util.*;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@RunWith(Parameterized.class)
public class ShufflesTests extends BaseNd4jTest {
public ShufflesTests(Nd4jBackend backend) { public class ShufflesTests extends BaseNd4jTestWithBackends {
super(backend);
}
@Test @Test
public void testSimpleShuffle1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimpleShuffle1(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(10, 10); INDArray array = Nd4j.zeros(10, 10);
for (int x = 0; x < 10; x++) { for (int x = 0; x < 10; x++) {
array.getRow(x).assign(x); array.getRow(x).assign(x);
@ -64,7 +63,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testSimpleShuffle2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimpleShuffle2(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(10, 10); INDArray array = Nd4j.zeros(10, 10);
for (int x = 0; x < 10; x++) { for (int x = 0; x < 10; x++) {
array.getColumn(x).assign(x); array.getColumn(x).assign(x);
@ -79,7 +80,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testSimpleShuffle3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimpleShuffle3(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(11, 10); INDArray array = Nd4j.zeros(11, 10);
for (int x = 0; x < 11; x++) { for (int x = 0; x < 11; x++) {
array.getRow(x).assign(x); array.getRow(x).assign(x);
@ -95,7 +98,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testSymmetricShuffle1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSymmetricShuffle1(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10); INDArray features = Nd4j.zeros(10, 10);
INDArray labels = Nd4j.zeros(10, 3); INDArray labels = Nd4j.zeros(10, 3);
for (int x = 0; x < 10; x++) { for (int x = 0; x < 10; x++) {
@ -133,7 +138,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testSymmetricShuffle2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSymmetricShuffle2(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10, 20); INDArray features = Nd4j.zeros(10, 10, 20);
INDArray labels = Nd4j.zeros(10, 10, 3); INDArray labels = Nd4j.zeros(10, 10, 3);
@ -171,7 +178,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testSymmetricShuffle3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSymmetricShuffle3(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10, 20); INDArray features = Nd4j.zeros(10, 10, 20);
INDArray featuresMask = Nd4j.zeros(10, 20); INDArray featuresMask = Nd4j.zeros(10, 20);
INDArray labels = Nd4j.zeros(10, 10, 3); INDArray labels = Nd4j.zeros(10, 10, 3);
@ -236,7 +245,9 @@ public class ShufflesTests extends BaseNd4jTest {
* @throws Exception * @throws Exception
*/ */
@Test @Test
public void testHalfVectors1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testHalfVectors1(Nd4jBackend backend) {
int[] array1 = ArrayUtil.buildHalfVector(new Random(12), 20); int[] array1 = ArrayUtil.buildHalfVector(new Random(12), 20);
int[] array2 = ArrayUtil.buildHalfVector(new Random(75), 20); int[] array2 = ArrayUtil.buildHalfVector(new Random(75), 20);
@ -257,7 +268,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testInterleavedVector1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInterleavedVector1(Nd4jBackend backend) {
int[] array1 = ArrayUtil.buildInterleavedVector(new Random(12), 20); int[] array1 = ArrayUtil.buildInterleavedVector(new Random(12), 20);
int[] array2 = ArrayUtil.buildInterleavedVector(new Random(75), 20); int[] array2 = ArrayUtil.buildInterleavedVector(new Random(75), 20);
@ -278,7 +291,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testInterleavedVector3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInterleavedVector3(Nd4jBackend backend) {
for (int e = 0; e < 1000; e++) { for (int e = 0; e < 1000; e++) {
int length = e + 256; //RandomUtils.nextInt(121, 2073); int length = e + 256; //RandomUtils.nextInt(121, 2073);
int[] array1 = ArrayUtil.buildInterleavedVector(new Random(System.currentTimeMillis()), length); int[] array1 = ArrayUtil.buildInterleavedVector(new Random(System.currentTimeMillis()), length);

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.eigen.Eigen; import org.nd4j.linalg.eigen.Eigen;
@ -35,16 +36,11 @@ import org.nd4j.common.util.ArrayUtil;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
@Slf4j @Slf4j
public class TestEigen extends BaseNd4jTest { public class TestEigen extends BaseNd4jTestWithBackends {
protected DataType initialType; protected DataType initialType = Nd4j.dataType();
public TestEigen(Nd4jBackend backend) {
super(backend);
initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
public void before() { public void before() {
@ -59,7 +55,9 @@ public class TestEigen extends BaseNd4jTest {
// test of functions added by Luke Czapla // test of functions added by Luke Czapla
// Compares solution of A x = L x to solution to A x = L B x when it is simple // Compares solution of A x = L x to solution to A x = L B x when it is simple
@Test @Test
public void test2Syev() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void test2Syev(Nd4jBackend backend) {
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
Nd4j.setDefaultDataTypes(dt, dt); Nd4j.setDefaultDataTypes(dt, dt);
@ -78,7 +76,9 @@ public class TestEigen extends BaseNd4jTest {
} }
@Test @Test
public void testSyev() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSyev(Nd4jBackend backend) {
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
//log.info("Datatype: {}", dt); //log.info("Datatype: {}", dt);
Nd4j.setDefaultDataTypes(dt, dt); Nd4j.setDefaultDataTypes(dt, dt);

View File

@ -24,23 +24,23 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.util.ArrayUtil; import org.nd4j.common.util.ArrayUtil;
@RunWith(Parameterized.class)
@Slf4j @Slf4j
public class ToStringTest extends BaseNd4jTest { public class ToStringTest extends BaseNd4jTestWithBackends {
public ToStringTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testToString() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToString(Nd4jBackend backend) throws Exception {
assertEquals("[ 1, 2, 3]", assertEquals("[ 1, 2, 3]",
Nd4j.createFromArray(1, 2, 3).toString()); Nd4j.createFromArray(1, 2, 3).toString());
@ -58,6 +58,8 @@ public class ToStringTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToStringScalars(){ public void testToStringScalars(){
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32}; DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"}; String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.activations;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.impl.ActivationCube; import org.nd4j.linalg.activations.impl.ActivationCube;
import org.nd4j.linalg.activations.impl.ActivationELU; import org.nd4j.linalg.activations.impl.ActivationELU;
import org.nd4j.linalg.activations.impl.ActivationGELU; import org.nd4j.linalg.activations.impl.ActivationGELU;
@ -55,12 +56,9 @@ import java.util.List;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class TestActivation extends BaseNd4jTest {
public TestActivation(Nd4jBackend backend) { public class TestActivation extends BaseNd4jTestWithBackends {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -79,7 +77,9 @@ public class TestActivation extends BaseNd4jTest {
} }
@Test @Test
public void testRelu(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRelu(Nd4jBackend backend){
Double[] max = {null, 6.0, 2.5, 5.0}; Double[] max = {null, 6.0, 2.5, 5.0};
Double[] threshold = {0.0, 0.0, 0.75, 0.2}; Double[] threshold = {0.0, 0.0, 0.75, 0.2};
@ -131,7 +131,9 @@ public class TestActivation extends BaseNd4jTest {
} }
@Test @Test
public void testJson() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testJson(Nd4jBackend backend) throws Exception {
IActivation[] activations = new IActivation[] {new ActivationCube(), new ActivationELU(0.25), IActivation[] activations = new IActivation[] {new ActivationCube(), new ActivationELU(0.25),
new ActivationHardSigmoid(), new ActivationHardTanH(), new ActivationIdentity(), new ActivationHardSigmoid(), new ActivationHardTanH(), new ActivationIdentity(),

View File

@ -20,21 +20,20 @@
package org.nd4j.linalg.api; package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.nd4j.linalg.factory.Environment; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
public class TestBackend extends BaseNd4jTest { public class TestBackend extends BaseNd4jTestWithBackends {
public TestBackend(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void TestBuildInfo(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBuildInfo(Nd4jBackend backend){
System.out.println("Backend build info: " + backend.buildInfo()); System.out.println("Backend build info: " + backend.buildInfo());
} }
} }

View File

@ -20,18 +20,17 @@
package org.nd4j.linalg.api; package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Environment; import org.nd4j.linalg.factory.Environment;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
public class TestEnvironment extends BaseNd4jTest { public class TestEnvironment extends BaseNd4jTestWithBackends {
public TestEnvironment(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -39,7 +38,9 @@ public class TestEnvironment extends BaseNd4jTest {
} }
@Test @Test
public void testEnvironment(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEnvironment(Nd4jBackend backend){
Environment e = Nd4j.getEnvironment(); Environment e = Nd4j.getEnvironment();
System.out.println("BLAS version: " + e.blasMajorVersion() + "." + e.blasMinorVersion() + "." + e.blasPatchVersion()); System.out.println("BLAS version: " + e.blasMajorVersion() + "." + e.blasMinorVersion() + "." + e.blasPatchVersion());
System.out.println("CPU: " + e.isCPU()); System.out.println("CPU: " + e.isCPU());

View File

@ -26,7 +26,9 @@ import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -40,16 +42,12 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class TestNDArrayCreation extends BaseNd4jTest { public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
public TestNDArrayCreation(Nd4jBackend backend) {
super(backend);
}
@Test @Test
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") @ParameterizedTest
public void testBufferCreation() { @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBufferCreation(Nd4jBackend backend) {
DataBuffer dataBuffer = Nd4j.createBuffer(new float[] {1, 2}); DataBuffer dataBuffer = Nd4j.createBuffer(new float[] {1, 2});
Pointer pointer = dataBuffer.pointer(); Pointer pointer = dataBuffer.pointer();
FloatPointer floatPointer = new FloatPointer(pointer); FloatPointer floatPointer = new FloatPointer(pointer);
@ -69,6 +67,8 @@ public class TestNDArrayCreation extends BaseNd4jTest {
@Test @Test
@Disabled @Disabled
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCreateNpy() throws Exception { public void testCreateNpy() throws Exception {
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile()); INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile());
assertEquals(2, arrCreate.size(0)); assertEquals(2, arrCreate.size(0));
@ -82,7 +82,9 @@ public class TestNDArrayCreation extends BaseNd4jTest {
@Test @Test
@Disabled @Disabled
public void testCreateNpz() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCreateNpz(Nd4jBackend backend) throws Exception {
Map<String, INDArray> map = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile()); Map<String, INDArray> map = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile());
assertEquals(true, map.containsKey("x")); assertEquals(true, map.containsKey("x"));
assertEquals(true, map.containsKey("y")); assertEquals(true, map.containsKey("y"));
@ -100,8 +102,7 @@ public class TestNDArrayCreation extends BaseNd4jTest {
} }
@Test @Test
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") public void testCreateNpy3(Nd4jBackend backend) throws Exception {
public void testCreateNpy3() throws Exception {
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile()); INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
assertEquals(8, arrCreate.length()); assertEquals(8, arrCreate.length());
assertEquals(3, arrCreate.rank()); assertEquals(3, arrCreate.rank());
@ -113,7 +114,7 @@ public class TestNDArrayCreation extends BaseNd4jTest {
@Test @Test
@Disabled // this is endless test @Disabled // this is endless test
public void testEndlessAllocation() { public void testEndlessAllocation(Nd4jBackend backend) {
Nd4j.getEnvironment().setMaxSpecialMemory(1); Nd4j.getEnvironment().setMaxSpecialMemory(1);
while (true) { while (true) {
val arr = Nd4j.createUninitialized(DataType.FLOAT, 100000000); val arr = Nd4j.createUninitialized(DataType.FLOAT, 100000000);

View File

@ -21,24 +21,23 @@
package org.nd4j.linalg.api; package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil; import org.nd4j.common.util.ArrayUtil;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
public class TestNDArrayCreationUtil extends BaseNd4jTest { public class TestNDArrayCreationUtil extends BaseNd4jTestWithBackends {
public TestNDArrayCreationUtil(Nd4jBackend backend) {
super(backend);
}
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShapes() { public void testShapes() {
long[] shape2d = {2, 3}; long[] shape2d = {2, 3};

View File

@ -21,20 +21,21 @@
package org.nd4j.linalg.api; package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
public class TestNamespaces extends BaseNd4jTest { public class TestNamespaces extends BaseNd4jTestWithBackends {
public TestNamespaces(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testBitwiseSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBitwiseSimple(Nd4jBackend backend){
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT); INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
INDArray y = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT); INDArray y = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
@ -50,7 +51,9 @@ public class TestNamespaces extends BaseNd4jTest {
} }
@Test @Test
public void testMathSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMathSimple(Nd4jBackend backend) {
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1); INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1);
INDArray abs = Nd4j.math.abs(x); INDArray abs = Nd4j.math.abs(x);
// System.out.println(x); // System.out.println(x);
@ -65,7 +68,9 @@ public class TestNamespaces extends BaseNd4jTest {
} }
@Test @Test
public void testRandomSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomSimple(Nd4jBackend backend){
INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10); INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10);
// System.out.println(normal); // System.out.println(normal);
INDArray uniform = Nd4j.random.uniform(0, 1, DataType.FLOAT, 10); INDArray uniform = Nd4j.random.uniform(0, 1, DataType.FLOAT, 10);
@ -73,7 +78,9 @@ public class TestNamespaces extends BaseNd4jTest {
} }
@Test @Test
public void testNeuralNetworkSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNeuralNetworkSimple(Nd4jBackend backend){
INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10)); INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10));
// System.out.println(out); // System.out.println(out);
INDArray out2 = Nd4j.nn.softmax(Nd4j.random.normal(0, 1, DataType.FLOAT, 4, 5), 1); INDArray out2 = Nd4j.nn.softmax(Nd4j.random.normal(0, 1, DataType.FLOAT, 4, 5), 1);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.blas; package org.nd4j.linalg.api.blas;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -31,15 +32,14 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class LapackTest extends BaseNd4jTest { public class LapackTest extends BaseNd4jTestWithBackends {
public LapackTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testQRSquare() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testQRSquare(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9}); INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9});
A = A.reshape('c', 3, 3); A = A.reshape('c', 3, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape()); INDArray O = Nd4j.create(A.dataType(), A.shape());
@ -57,7 +57,9 @@ public class LapackTest extends BaseNd4jTest {
} }
@Test @Test
public void testQRRect() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testQRRect(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
A = A.reshape('f', 4, 3); A = A.reshape('f', 4, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape()); INDArray O = Nd4j.create(A.dataType(), A.shape());
@ -75,7 +77,9 @@ public class LapackTest extends BaseNd4jTest {
} }
@Test @Test
public void testCholeskyL() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCholeskyL(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {2, -1, 1, -1, 2, -1, 1, -1, 2,}); INDArray A = Nd4j.create(new double[] {2, -1, 1, -1, 2, -1, 1, -1, 2,});
A = A.reshape('c', 3, 3); A = A.reshape('c', 3, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape()); INDArray O = Nd4j.create(A.dataType(), A.shape());
@ -92,7 +96,9 @@ public class LapackTest extends BaseNd4jTest {
} }
@Test @Test
public void testCholeskyU() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCholeskyU(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,}); INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,});
A = A.reshape('f', 3, 3); A = A.reshape('f', 3, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape()); INDArray O = Nd4j.create(A.dataType(), A.shape());

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.api.blas;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -35,14 +36,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class Level1Test extends BaseNd4jTest { public class Level1Test extends BaseNd4jTestWithBackends {
public Level1Test(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testDot() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDot(Nd4jBackend backend) {
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4});
assertEquals(30, Nd4j.getBlasWrapper().dot(vec1, vec2), 1e-1); assertEquals(30, Nd4j.getBlasWrapper().dot(vec1, vec2), 1e-1);
@ -55,7 +55,9 @@ public class Level1Test extends BaseNd4jTest {
} }
@Test @Test
public void testAxpy() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAxpy(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray row = matrix.getRow(1); INDArray row = matrix.getRow(1);
Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row); Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row);
@ -64,7 +66,9 @@ public class Level1Test extends BaseNd4jTest {
} }
@Test @Test
public void testAxpy2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAxpy2(Nd4jBackend backend) {
val rowX = Nd4j.create(new double[]{1, 2, 3, 4}); val rowX = Nd4j.create(new double[]{1, 2, 3, 4});
val rowY = Nd4j.create(new double[]{1, 2, 3, 4}); val rowY = Nd4j.create(new double[]{1, 2, 3, 4});
val exp = Nd4j.create(new double[]{3, 6, 9, 12}); val exp = Nd4j.create(new double[]{3, 6, 9, 12});

View File

@ -21,23 +21,23 @@
package org.nd4j.linalg.api.blas; package org.nd4j.linalg.api.blas;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class Level2Test extends BaseNd4jTest { public class Level2Test extends BaseNd4jTestWithBackends {
public Level2Test(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testGemv1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv1(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -51,7 +51,9 @@ public class Level2Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemv2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv2(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -65,7 +67,9 @@ public class Level2Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemv3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv3(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -79,7 +83,9 @@ public class Level2Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemv4() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv4(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -93,7 +99,9 @@ public class Level2Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemv5() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv5(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -109,7 +117,9 @@ public class Level2Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemv6() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv6(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -125,7 +135,9 @@ public class Level2Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemv7() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv7(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);

View File

@ -21,23 +21,23 @@
package org.nd4j.linalg.api.blas; package org.nd4j.linalg.api.blas;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class Level3Test extends BaseNd4jTest { public class Level3Test extends BaseNd4jTestWithBackends {
public Level3Test(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testGemm1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm1(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape(1, 100); INDArray array1 = Nd4j.linspace(1, 100, 100).reshape(1, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -47,7 +47,9 @@ public class Level3Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemm2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm2(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape('f', 1, 100); INDArray array1 = Nd4j.linspace(1, 100, 100).reshape('f', 1, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -57,7 +59,9 @@ public class Level3Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemm3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm3(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
@ -75,7 +79,9 @@ public class Level3Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemm4() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm4(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);
@ -92,7 +98,9 @@ public class Level3Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemm5() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm5(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
@ -106,7 +114,9 @@ public class Level3Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemm6() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm6(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.blas.params; package org.nd4j.linalg.api.blas.params;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -33,16 +34,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class ParamsTestsF extends BaseNd4jTest {
public class ParamsTestsF extends BaseNd4jTestWithBackends {
public ParamsTestsF(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testGemm() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm (Nd4jBackend backend) {
INDArray a = Nd4j.create(2, 2); INDArray a = Nd4j.create(2, 2);
INDArray b = Nd4j.create(2, 3); INDArray b = Nd4j.create(2, 3);
INDArray c = Nd4j.create(2, 3); INDArray c = Nd4j.create(2, 3);

View File

@ -25,9 +25,10 @@ import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*; import org.bytedeco.javacpp.indexer.*;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
@ -45,16 +46,15 @@ import java.nio.ByteOrder;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class DataBufferTests extends BaseNd4jTest {
public DataBufferTests(Nd4jBackend backend) { public class DataBufferTests extends BaseNd4jTestWithBackends {
super(backend);
}
@Test @Test
@Disabled("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657") @Disabled("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657")
public void testNoArgCreateBufferFromArray() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNoArgCreateBufferFromArray(Nd4jBackend backend) {
//Tests here: //Tests here:
//1. Create from JVM array //1. Create from JVM array
@ -280,7 +280,9 @@ public class DataBufferTests extends BaseNd4jTest {
@Test @Test
public void testCreateTypedBuffer() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCreateTypedBuffer(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
@ -350,7 +352,9 @@ public class DataBufferTests extends BaseNd4jTest {
} }
@Test @Test
public void testAsBytes() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAsBytes(Nd4jBackend backend) {
INDArray orig = Nd4j.linspace(DataType.INT, 0, 10, 1); INDArray orig = Nd4j.linspace(DataType.INT, 0, 10, 1);
for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16, for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16,
@ -405,6 +409,8 @@ public class DataBufferTests extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEnsureLocation(){ public void testEnsureLocation(){
//https://github.com/eclipse/deeplearning4j/issues/8783 //https://github.com/eclipse/deeplearning4j/issues/8783
Nd4j.create(1); Nd4j.create(1);

View File

@ -23,9 +23,10 @@ package org.nd4j.linalg.api.buffer;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -33,13 +34,10 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@RunWith(Parameterized.class)
public class DataTypeValidationTests extends BaseNd4jTest {
DataType initialType;
public DataTypeValidationTests(Nd4jBackend backend) { public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
super(backend); DataType initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
public void setUp() { public void setUp() {
@ -48,7 +46,7 @@ public class DataTypeValidationTests extends BaseNd4jTest {
} }
@AfterEach @AfterEach
public void shutUp() { public void reset() {
Nd4j.setDataType(initialType); Nd4j.setDataType(initialType);
} }
@ -73,7 +71,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
* Testing level1 blas * Testing level1 blas
*/ */
@Test() @Test()
public void testBlasValidation1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBlasValidation1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> { assertThrows(ND4JIllegalStateException.class,() -> {
INDArray x = Nd4j.create(10); INDArray x = Nd4j.create(10);
@ -90,7 +90,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
* Testing level2 blas * Testing level2 blas
*/ */
@Test() @Test()
public void testBlasValidation2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBlasValidation2(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> { assertThrows(RuntimeException.class,() -> {
INDArray a = Nd4j.create(100, 10); INDArray a = Nd4j.create(100, 10);
INDArray x = Nd4j.create(100); INDArray x = Nd4j.create(100);
@ -108,7 +110,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
* Testing level3 blas * Testing level3 blas
*/ */
@Test() @Test()
public void testBlasValidation3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBlasValidation3(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
INDArray x = Nd4j.create(100, 100); INDArray x = Nd4j.create(100, 100);

View File

@ -26,9 +26,10 @@ import org.bytedeco.javacpp.indexer.Indexer;
import org.junit.jupiter.api.*; import org.junit.jupiter.api.*;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
@ -54,34 +55,31 @@ import static org.junit.jupiter.api.Assertions.*;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public class DoubleDataBufferTest extends BaseNd4jTest { public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
DataType initialType; DataType initialType = Nd4j.dataType();
public DoubleDataBufferTest(Nd4jBackend backend) {
super(backend);
initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
public void before() { public void before(Nd4jBackend backend) {
DataTypeUtil.setDTypeForContext(DataType.DOUBLE); DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
} }
@AfterEach @AfterEach
public void after() { public void after(Nd4jBackend backend) {
DataTypeUtil.setDTypeForContext(initialType); DataTypeUtil.setDTypeForContext(initialType);
} }
@Test @Test
public void testPointerCreation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPointerCreation(Nd4jBackend backend) {
DoublePointer floatPointer = new DoublePointer(1, 2, 3, 4); DoublePointer floatPointer = new DoublePointer(1, 2, 3, 4);
Indexer indexer = DoubleIndexer.create(floatPointer); Indexer indexer = DoubleIndexer.create(floatPointer);
DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.DOUBLE, 4, indexer); DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.DOUBLE, 4, indexer);
@ -90,7 +88,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testGetSet() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetSet(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4}; double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
double[] d2 = d.asDouble(); double[] d2 = d.asDouble();
@ -101,6 +101,8 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerialization2() throws Exception { public void testSerialization2() throws Exception {
INDArray[] arr = new INDArray[] {Nd4j.ones(1, 10), INDArray[] arr = new INDArray[] {Nd4j.ones(1, 10),
// Nd4j.ones(5,10).getRow(2) // Nd4j.ones(5,10).getRow(2)
@ -129,6 +131,8 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerialization(@TempDir Path testDir) throws Exception { public void testSerialization(@TempDir Path testDir) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
DataBuffer buf = Nd4j.createBuffer(5); DataBuffer buf = Nd4j.createBuffer(5);
@ -151,7 +155,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testDup() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDup(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4}; double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
DataBuffer d2 = d.dup(); DataBuffer d2 = d.dup();
@ -161,7 +167,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testPut() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPut(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4}; double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
d.put(0, 0.0); d.put(0, 0.0);
@ -172,7 +180,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testGetRange() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
double[] get = buffer.getDoublesAt(0, 3); double[] get = buffer.getDoublesAt(0, 3);
double[] data = new double[] {1, 2, 3}; double[] data = new double[] {1, 2, 3};
@ -187,7 +197,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testGetOffsetRange() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetOffsetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
double[] get = buffer.getDoublesAt(1, 3); double[] get = buffer.getDoublesAt(1, 3);
double[] data = new double[] {2, 3, 4}; double[] data = new double[] {2, 3, 4};
@ -202,7 +214,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testAssign() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssign(Nd4jBackend backend) {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3}); DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
DataBuffer one = Nd4j.createBuffer(new double[] {1}); DataBuffer one = Nd4j.createBuffer(new double[] {1});
DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3}); DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3});
@ -213,7 +227,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testOffset() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOffset(Nd4jBackend backend) {
DataBuffer create = Nd4j.createBuffer(new double[] {1, 2, 3, 4}, 2); DataBuffer create = Nd4j.createBuffer(new double[] {1, 2, 3, 4}, 2);
assertEquals(2, create.length()); assertEquals(2, create.length());
assertEquals(0, create.offset()); assertEquals(0, create.offset());
@ -223,7 +239,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testReallocation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4}); DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity()); assertEquals(4, buffer.capacity());
double[] old = buffer.asDouble(); double[] old = buffer.asDouble();
@ -233,7 +251,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testReallocationWorkspace() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID"); MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
@ -250,6 +270,8 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddressPointer(){ public void testAddressPointer(){
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){ if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
return; return;

View File

@ -27,7 +27,9 @@ import org.bytedeco.javacpp.indexer.Indexer;
import org.junit.jupiter.api.*; import org.junit.jupiter.api.*;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
@ -54,14 +56,9 @@ import static org.junit.jupiter.api.Assertions.*;
* @author Adam Gibson * @author Adam Gibson
*/ */
@Disabled("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") @Disabled("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public class FloatDataBufferTest extends BaseNd4jTest { public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
DataType initialType; DataType initialType = Nd4j.dataType();
public FloatDataBufferTest(Nd4jBackend backend) {
super(backend);
initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
public void before() { public void before() {
@ -76,7 +73,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testPointerCreation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPointerCreation(Nd4jBackend backend) {
FloatPointer floatPointer = new FloatPointer(1, 2, 3, 4); FloatPointer floatPointer = new FloatPointer(1, 2, 3, 4);
Indexer indexer = FloatIndexer.create(floatPointer); Indexer indexer = FloatIndexer.create(floatPointer);
DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.FLOAT, 4, indexer); DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.FLOAT, 4, indexer);
@ -85,7 +84,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testGetSet() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetSet(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4}; float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
float[] d2 = d.asFloat(); float[] d2 = d.asFloat();
@ -96,7 +97,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testSerialization(@TempDir Path tempDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerialization(@TempDir Path tempDir,Nd4jBackend backend) throws Exception {
File dir = tempDir.toFile(); File dir = tempDir.toFile();
DataBuffer buf = Nd4j.createBuffer(5); DataBuffer buf = Nd4j.createBuffer(5);
String fileName = "buf.ser"; String fileName = "buf.ser";
@ -117,7 +120,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testDup() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDup(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4}; float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
DataBuffer d2 = d.dup(); DataBuffer d2 = d.dup();
@ -125,7 +130,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testToNio() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToNio(Nd4jBackend backend) {
DataBuffer buff = Nd4j.createTypedBuffer(new double[] {1, 2, 3, 4}, DataType.FLOAT); DataBuffer buff = Nd4j.createTypedBuffer(new double[] {1, 2, 3, 4}, DataType.FLOAT);
assertEquals(4, buff.length()); assertEquals(4, buff.length());
if (buff.allocationMode() == DataBuffer.AllocationMode.HEAP) if (buff.allocationMode() == DataBuffer.AllocationMode.HEAP)
@ -137,7 +144,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testPut() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPut(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4}; float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
d.put(0, 0.0); d.put(0, 0.0);
@ -148,7 +157,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testGetRange() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(0, 3); float[] get = buffer.getFloatsAt(0, 3);
float[] data = new float[] {1, 2, 3}; float[] data = new float[] {1, 2, 3};
@ -164,7 +175,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testGetOffsetRange() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetOffsetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(1, 3); float[] get = buffer.getFloatsAt(1, 3);
float[] data = new float[] {2, 3, 4}; float[] data = new float[] {2, 3, 4};
@ -181,7 +194,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testAsBytes() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAsBytes(Nd4jBackend backend) {
INDArray arr = Nd4j.create(5); INDArray arr = Nd4j.create(5);
byte[] d = arr.data().asBytes(); byte[] d = arr.data().asBytes();
assertEquals(4 * 5, d.length,getFailureMessage()); assertEquals(4 * 5, d.length,getFailureMessage());
@ -191,7 +206,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testAssign() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssign(Nd4jBackend backend) {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3}); DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
DataBuffer one = Nd4j.createBuffer(new double[] {1}); DataBuffer one = Nd4j.createBuffer(new double[] {1});
DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3}); DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3});
@ -201,7 +218,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testReadWrite() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReadWrite(Nd4jBackend backend) throws Exception {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3}); DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
ByteArrayOutputStream bos = new ByteArrayOutputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(bos); DataOutputStream dos = new DataOutputStream(bos);
@ -215,7 +234,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testOffset() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOffset(Nd4jBackend backend) {
DataBuffer create = Nd4j.createBuffer(new float[] {1, 2, 3, 4}, 2); DataBuffer create = Nd4j.createBuffer(new float[] {1, 2, 3, 4}, 2);
assertEquals(2, create.length()); assertEquals(2, create.length());
assertEquals(0, create.offset()); assertEquals(0, create.offset());
@ -225,7 +246,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testReallocation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4}); DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity()); assertEquals(4, buffer.capacity());
float[] old = buffer.asFloat(); float[] old = buffer.asFloat();
@ -236,7 +259,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testReallocationWorkspace() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID"); MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
@ -253,7 +278,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testAddressPointer(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddressPointer(Nd4jBackend backend){
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){ if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
return; return;
} }

View File

@ -23,7 +23,9 @@ package org.nd4j.linalg.api.buffer;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy; import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
@ -37,13 +39,12 @@ import java.util.Arrays;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class IntDataBufferTests extends BaseNd4jTest { public class IntDataBufferTests extends BaseNd4jTestWithBackends {
public IntDataBufferTests(Nd4jBackend backend) {
super(backend);
}
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBasicSerde1() throws Exception { public void testBasicSerde1() throws Exception {
@ -82,7 +83,9 @@ public class IntDataBufferTests extends BaseNd4jTest {
*/ */
@Test @Test
public void testReallocation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4}); DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity()); assertEquals(4, buffer.capacity());
buffer.reallocate(6); buffer.reallocate(6);
@ -94,7 +97,9 @@ public class IntDataBufferTests extends BaseNd4jTest {
} }
@Test @Test
public void testReallocationWorkspace() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID"); MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing; package org.nd4j.linalg.api.indexing;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -37,17 +38,15 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class) public class IndexingTests extends BaseNd4jTestWithBackends {
public class IndexingTests extends BaseNd4jTest {
public IndexingTests(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testINDArrayIndexingEqualToRank() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testINDArrayIndexingEqualToRank(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE); INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{ INDArray indexes = Nd4j.create(new double[][]{
{0,1,2}, {0,1,2},
@ -62,7 +61,9 @@ public class IndexingTests extends BaseNd4jTest {
@Test @Test
public void testINDArrayIndexingLessThanRankSimple() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testINDArrayIndexingLessThanRankSimple(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE); INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{ INDArray indexes = Nd4j.create(new double[][]{
{0}, {0},
@ -76,7 +77,9 @@ public class IndexingTests extends BaseNd4jTest {
@Test @Test
public void testINDArrayIndexingLessThanRankFourDimension() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testINDArrayIndexingLessThanRankFourDimension(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE); INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{ INDArray indexes = Nd4j.create(new double[][]{
{0},{1} {0},{1}
@ -89,7 +92,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testPutSimple() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutSimple(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2); INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2);
INDArray indexes = Nd4j.create(new double[][]{ INDArray indexes = Nd4j.create(new double[][]{
{0},{1} {0},{1}
@ -101,7 +106,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testGetScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetScalar(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
INDArray d = arr.get(NDArrayIndex.point(1)); INDArray d = arr.get(NDArrayIndex.point(1));
assertTrue(d.isScalar()); assertTrue(d.isScalar());
@ -110,14 +117,18 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testNewAxis() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis(Nd4jBackend backend) {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3}); INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1)); INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
// System.out.println(view); // System.out.println(view);
} }
@Test @Test
public void testVectorIndexing() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorIndexing(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE); INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE);
int[] index = new int[] {5, 8, 9}; int[] index = new int[] {5, 8, 9};
INDArray columnsTest = x.getColumns(index); INDArray columnsTest = x.getColumns(index);
@ -129,7 +140,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testGetRowsColumnsMatrix() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRowsColumnsMatrix(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 6); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 6);
INDArray firstAndSecondColumnsAssertion = Nd4j.create(new double[][] {{1, 5}, {2, 6}, {3, 7}, {4, 8}}); INDArray firstAndSecondColumnsAssertion = Nd4j.create(new double[][] {{1, 5}, {2, 6}, {3, 7}, {4, 8}});
@ -147,7 +160,9 @@ public class IndexingTests extends BaseNd4jTest {
@Test @Test
public void testSlicing() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlicing(Nd4jBackend backend) {
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE); INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
INDArray slice1Assert = Nd4j.create(new double[] {2, 6, 10, 14}); INDArray slice1Assert = Nd4j.create(new double[] {2, 6, 10, 14});
INDArray slice1Test = arange.slice(1); INDArray slice1Test = arange.slice(1);
@ -155,7 +170,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testArangeMul() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testArangeMul(Nd4jBackend backend) {
INDArray arange = Nd4j.arange(1, 17).reshape('f', 4, 4).castTo(DataType.DOUBLE); INDArray arange = Nd4j.arange(1, 17).reshape('f', 4, 4).castTo(DataType.DOUBLE);
INDArrayIndex index = NDArrayIndex.interval(0, 2); INDArrayIndex index = NDArrayIndex.interval(0, 2);
INDArray get = arange.get(index, index); INDArray get = arange.get(index, index);
@ -167,7 +184,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testGetIndicesVector() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndicesVector(Nd4jBackend backend) {
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1); INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
INDArray test = Nd4j.create(new double[] {2, 3}); INDArray test = Nd4j.create(new double[] {2, 3});
INDArray result = line.get(NDArrayIndex.point(0), NDArrayIndex.interval(1, 3)); INDArray result = line.get(NDArrayIndex.point(0), NDArrayIndex.interval(1, 3));
@ -175,7 +194,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testGetIndicesVectorView() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndicesVectorView(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape('c',5, 5); INDArray matrix = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape('c',5, 5);
INDArray column = matrix.getColumn(0).reshape(1,5); INDArray column = matrix.getColumn(0).reshape(1,5);
INDArray test = Nd4j.create(new double[] {6, 11}); INDArray test = Nd4j.create(new double[] {6, 11});
@ -193,7 +214,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void test2dGetPoint(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void test2dGetPoint(Nd4jBackend backend){
INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4); INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4);
for( int i=0; i<3; i++ ){ for( int i=0; i<3; i++ ){
INDArray exp = Nd4j.create(new double[]{i*4+1, i*4+2, i*4+3, i*4+4}); INDArray exp = Nd4j.create(new double[]{i*4+1, i*4+2, i*4+3, i*4+4});
@ -206,7 +229,7 @@ public class IndexingTests extends BaseNd4jTest {
assertEquals(exp, get); assertEquals(exp, get);
} }
for( int i=0; i<4; i++ ){ for( int i = 0; i < 4; i++) {
INDArray exp = Nd4j.create(new double[]{1+i, 5+i, 9+i}); INDArray exp = Nd4j.create(new double[]{1+i, 5+i, 9+i});
INDArray col = arr.getColumn(i); INDArray col = arr.getColumn(i);
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(i)); INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(i));

View File

@ -21,10 +21,11 @@
package org.nd4j.linalg.api.indexing; package org.nd4j.linalg.api.indexing;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -49,16 +50,15 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class IndexingTestsC extends BaseNd4jTest { public class IndexingTestsC extends BaseNd4jTestWithBackends {
public IndexingTestsC(Nd4jBackend backend) {
super(backend);
}
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNegativeBounds() { public void testNegativeBounds() {
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5); INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1)); INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
@ -71,6 +71,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis() { public void testNewAxis() {
INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2); INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2);
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all()); INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all());
@ -80,6 +82,8 @@ public class IndexingTestsC extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void broadcastBug() { public void broadcastBug() {
INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2}); INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2});
final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0)); final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0));
@ -91,6 +95,8 @@ public class IndexingTestsC extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIntervalsIn3D() { public void testIntervalsIn3D() {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE); INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2); INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
@ -100,6 +106,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSmallInterval() { public void testSmallInterval() {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE); INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2); INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
@ -109,6 +117,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllWithNewAxisAndInterval() { public void testAllWithNewAxisAndInterval() {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3); INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3);
@ -118,6 +128,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllWithNewAxisInMiddle() { public void testAllWithNewAxisInMiddle() {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3); INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3);
@ -127,6 +139,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllWithNewAxis() { public void testAllWithNewAxis() {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray get = arr.get(newAxis(), all(), point(1)); INDArray get = arr.get(newAxis(), all(), point(1));
@ -137,6 +151,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexingWithMmul() { public void testIndexingWithMmul() {
INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3); INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1); INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
@ -148,6 +164,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPointPointInterval() { public void testPointPointInterval() {
INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3); INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3);
INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3)); INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3));
@ -157,6 +175,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIntervalLowerBound() { public void testIntervalLowerBound() {
INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2)); INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
@ -168,6 +188,8 @@ public class IndexingTestsC extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetPointRowVector() { public void testGetPointRowVector() {
INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1); INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1);
@ -178,6 +200,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpecifiedIndexVector() { public void testSpecifiedIndexVector() {
INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4); INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2);
@ -195,6 +219,8 @@ public class IndexingTestsC extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRowIndexing() { public void testPutRowIndexing() {
INDArray arr = Nd4j.ones(1, 10); INDArray arr = Nd4j.ones(1, 10);
INDArray row = Nd4j.create(1, 10); INDArray row = Nd4j.create(1, 10);
@ -205,6 +231,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorIndexing2() { public void testVectorIndexing2() {
INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true)); INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true));
INDArray assertion = Nd4j.create(new double[] {2, 4}); INDArray assertion = Nd4j.create(new double[] {2, 4});
@ -220,6 +248,8 @@ public class IndexingTestsC extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOffsetsC() { public void testOffsetsC() {
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
assertEquals(3, NDArrayIndex.offset(arr, 1, 1)); assertEquals(3, NDArrayIndex.offset(arr, 1, 1));
@ -236,6 +266,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexFor() { public void testIndexFor() {
long[] shape = {1, 2}; long[] shape = {1, 2};
INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape); INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape);
@ -245,6 +277,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetScalar() { public void testGetScalar() {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
INDArray d = arr.get(point(1)); INDArray d = arr.get(point(1));
@ -254,6 +288,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorIndexing() { public void testVectorIndexing() {
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1); INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1);
INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5}); INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
@ -262,6 +298,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNegativeIndices() { public void testNegativeIndices() {
INDArray test = Nd4j.create(10, 10, 10); INDArray test = Nd4j.create(10, 10, 10);
test.putScalar(new int[] {0, 0, -1}, 1.0); test.putScalar(new int[] {0, 0, -1}, 1.0);
@ -269,6 +307,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndices2d() { public void testGetIndices2d() {
INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2); INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2);
INDArray firstRow = twoByTwo.getRow(0); INDArray firstRow = twoByTwo.getRow(0);
@ -287,6 +327,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRow() { public void testGetRow() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5); INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5);
@ -304,6 +346,8 @@ public class IndexingTestsC extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRowEdgeCase() { public void testGetRowEdgeCase() {
INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1); INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
INDArray get = rowVec.getRow(0); //Returning shape [1,1] INDArray get = rowVec.getRow(0); //Returning shape [1,1]
@ -313,6 +357,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumnEdgeCase() { public void testGetColumnEdgeCase() {
INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose(); INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose();
INDArray get = colVec.getColumn(0); //Returning shape [1,1] INDArray get = colVec.getColumn(0); //Returning shape [1,1]
@ -322,6 +368,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcatColumns() { public void testConcatColumns() {
INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE); INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE);
INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE); INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE);
@ -331,6 +379,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndicesVector() { public void testGetIndicesVector() {
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1); INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
INDArray test = Nd4j.create(new double[] {2, 3}); INDArray test = Nd4j.create(new double[] {2, 3});
@ -339,6 +389,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testArangeMul() { public void testArangeMul() {
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE); INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
INDArrayIndex index = interval(0, 2); INDArrayIndex index = interval(0, 2);
@ -350,6 +402,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexingThorough(){ public void testIndexingThorough(){
long[] fullShape = {3,4,5,6,7}; long[] fullShape = {3,4,5,6,7};
@ -550,6 +604,8 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void debugging(){ public void debugging(){
long[] inShape = {3,4}; long[] inShape = {3,4};
INDArrayIndex[] indexes = new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(1, 2, 4)}; INDArrayIndex[] indexes = new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(1, 2, 4)};

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing.resolve; package org.nd4j.linalg.api.indexing.resolve;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -36,15 +37,14 @@ import static org.junit.jupiter.api.Assertions.*;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class NDArrayIndexResolveTests extends BaseNd4jTest {
public NDArrayIndexResolveTests(Nd4jBackend backend) { public class NDArrayIndexResolveTests extends BaseNd4jTestWithBackends {
super(backend);
}
@Test @Test
public void testResolvePoint() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testResolvePoint(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2);
INDArrayIndex[] test = NDArrayIndex.resolve(arr.shape(), NDArrayIndex.point(1)); INDArrayIndex[] test = NDArrayIndex.resolve(arr.shape(), NDArrayIndex.point(1));
INDArrayIndex[] assertion = {NDArrayIndex.point(1), NDArrayIndex.all()}; INDArrayIndex[] assertion = {NDArrayIndex.point(1), NDArrayIndex.all()};
@ -59,6 +59,8 @@ public class NDArrayIndexResolveTests extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testResolvePointVector() { public void testResolvePointVector() {
INDArray arr = Nd4j.linspace(1, 4, 4); INDArray arr = Nd4j.linspace(1, 4, 4);
INDArrayIndex[] getPoint = {NDArrayIndex.point(1)}; INDArrayIndex[] getPoint = {NDArrayIndex.point(1)};

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing.shape; package org.nd4j.linalg.api.indexing.shape;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.Indices; import org.nd4j.linalg.indexing.Indices;
@ -34,19 +35,15 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class IndexShapeTests extends BaseNd4jTest {
public IndexShapeTests(Nd4jBackend backend) {
super(backend);
}
public class IndexShapeTests extends BaseNd4jTestWithBackends {
private int[] shape = {1, 1, 2, 1, 3, 4, 5, 1}; private int[] shape = {1, 1, 2, 1, 3, 4, 5, 1};
@Test @Test
public void testSinglePoint() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSinglePoint(Nd4jBackend backend) {
/* /*
Assumes all indexes are filled out. Assumes all indexes are filled out.
Test simple general point case Test simple general point case
@ -77,7 +74,9 @@ public class IndexShapeTests extends BaseNd4jTest {
} }
@Test @Test
public void testInterval() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInterval(Nd4jBackend backend) {
int[] basicAssertion = {1, 1, 1, 1, 3, 1, 2, 1}; int[] basicAssertion = {1, 1, 1, 1, 3, 1, 2, 1};
INDArrayIndex[] basicTest = {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 1), INDArrayIndex[] basicTest = {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 1),
NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(1, 2), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(1, 2),
@ -88,7 +87,9 @@ public class IndexShapeTests extends BaseNd4jTest {
@Test @Test
public void testNewAxis() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis(Nd4jBackend backend) {
//normal prepend //normal prepend
int[] prependAssertion = {1, 1, 1, 1, 2, 1, 3, 4, 5, 1}; int[] prependAssertion = {1, 1, 1, 1, 2, 1, 3, 4, 5, 1};
INDArrayIndex[] prependTest = {NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.all(), INDArrayIndex[] prependTest = {NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.all(),

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing.shape; package org.nd4j.linalg.api.indexing.shape;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.Indices; import org.nd4j.linalg.indexing.Indices;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
@ -33,25 +34,26 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class IndexShapeTests2d extends BaseNd4jTest {
public IndexShapeTests2d(Nd4jBackend backend) { public class IndexShapeTests2d extends BaseNd4jTestWithBackends {
super(backend);
}
private long[] shape = {3, 2}; private long[] shape = {3, 2};
@Test @Test
public void test2dCases() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void test2dCases(Nd4jBackend backend) {
assertArrayEquals(new long[] {1, 2}, Indices.shape(shape, NDArrayIndex.point(1))); assertArrayEquals(new long[] {1, 2}, Indices.shape(shape, NDArrayIndex.point(1)));
assertArrayEquals(new long[] {3, 1}, assertArrayEquals(new long[] {3, 1},
Indices.shape(shape, NDArrayIndex.all(), NDArrayIndex.point(1))); Indices.shape(shape, NDArrayIndex.all(), NDArrayIndex.point(1)));
} }
@Test @Test
public void testNewAxis2d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis2d(Nd4jBackend backend) {
assertArrayEquals(new long[] {1, 3, 2}, Indices.shape(shape, assertArrayEquals(new long[] {1, 3, 2}, Indices.shape(shape,
NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all())); NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all()));
assertArrayEquals(new long[] {3, 1, 2}, Indices.shape(shape, assertArrayEquals(new long[] {3, 1, 2}, Indices.shape(shape,

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.api.iterator;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -33,15 +34,14 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class NDIndexIteratorTest extends BaseNd4jTest {
public NDIndexIteratorTest(Nd4jBackend backend) { public class NDIndexIteratorTest extends BaseNd4jTestWithBackends {
super(backend);
}
@Test @Test
public void testIterate() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIterate(Nd4jBackend backend) {
val shapeIter = new NdIndexIterator(2, 2); val shapeIter = new NdIndexIterator(2, 2);
val possibleSolutions = new long[][] {{0, 0}, {0, 1}, {1, 0}, {1, 1},}; val possibleSolutions = new long[][] {{0, 0}, {0, 1}, {1, 0}, {1, 1},};

View File

@ -28,9 +28,10 @@ import org.apache.commons.lang3.ArrayUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -45,18 +46,15 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class TestNdArrReadWriteTxt extends BaseNd4jTest {
public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
public TestNdArrReadWriteTxt(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void compareAfterWrite(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
int [] ranksToCheck = new int[] {0,1,2,3,4}; int [] ranksToCheck = new int[] {0,1,2,3,4};
for (int i=0; i<ranksToCheck.length;i++) { for (int i = 0; i < ranksToCheck.length; i++) {
// log.info("Checking read write arrays with rank " + ranksToCheck[i]); // log.info("Checking read write arrays with rank " + ranksToCheck[i]);
compareArrays(ranksToCheck[i],ordering(), testDir); compareArrays(ranksToCheck[i],ordering(), testDir);
} }
@ -84,7 +82,9 @@ public class TestNdArrReadWriteTxt extends BaseNd4jTest {
} }
@Test @Test
public void testNd4jReadWriteText(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNd4jReadWriteText(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
int count = 0; int count = 0;

View File

@ -25,9 +25,10 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import java.nio.file.Path; import java.nio.file.Path;
@ -35,17 +36,14 @@ import java.nio.file.Path;
import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays; import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class TestNdArrReadWriteTxtC extends BaseNd4jTest {
public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends {
public TestNdArrReadWriteTxtC(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void compareAfterWrite(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
int[] ranksToCheck = new int[]{0, 1, 2, 3, 4}; int[] ranksToCheck = new int[]{0, 1, 2, 3, 4};
for (int i = 0; i < ranksToCheck.length; i++) { for (int i = 0; i < ranksToCheck.length; i++) {
log.info("Checking read write arrays with rank " + ranksToCheck[i]); log.info("Checking read write arrays with rank " + ranksToCheck[i]);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.ndarray; package org.nd4j.linalg.api.ndarray;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
@ -32,16 +33,14 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class TestSerialization extends BaseNd4jTest {
public TestSerialization(Nd4jBackend backend) { public class TestSerialization extends BaseNd4jTestWithBackends {
super(backend);
}
@Test @Test
public void testSerializationFullArrayNd4jWriteRead() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10); INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10); INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
@ -71,7 +70,9 @@ public class TestSerialization extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationFullArrayJava() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayJava(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10); INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10); INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
@ -102,7 +103,9 @@ public class TestSerialization extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationOnViewsNd4jWriteRead() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10); INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10); INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
@ -138,7 +141,9 @@ public class TestSerialization extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationOnViewsJava() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsJava(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10); INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10); INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);

View File

@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -39,15 +40,11 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class TestSerializationDoubleToFloat extends BaseNd4jTest {
DataType initialType; public class TestSerializationDoubleToFloat extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
public TestSerializationDoubleToFloat(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@AfterEach @AfterEach
public void after() { public void after() {
@ -55,7 +52,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationFullArrayNd4jWriteRead() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 4; int length = 4;
//WRITE OUT A DOUBLE ARRAY //WRITE OUT A DOUBLE ARRAY
@ -93,7 +92,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationFullArrayJava() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayJava(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
Nd4j.create(1); Nd4j.create(1);
DataTypeUtil.setDTypeForContext(DataType.DOUBLE); DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
@ -123,7 +124,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationOnViewsNd4jWriteRead() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
Nd4j.create(1); Nd4j.create(1);
DataTypeUtil.setDTypeForContext(DataType.DOUBLE); DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
@ -153,7 +156,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationOnViewsJava() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsJava(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
Nd4j.create(1); Nd4j.create(1);
DataTypeUtil.setDTypeForContext(DataType.DOUBLE); DataTypeUtil.setDTypeForContext(DataType.DOUBLE);

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.api.ndarray;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -37,15 +38,11 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class TestSerializationFloatToDouble extends BaseNd4jTest {
DataType initialType; public class TestSerializationFloatToDouble extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
public TestSerializationFloatToDouble(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@AfterEach @AfterEach
public void after() { public void after() {
@ -53,7 +50,9 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationFullArrayNd4jWriteRead() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
//WRITE OUT A FLOAT ARRAY //WRITE OUT A FLOAT ARRAY
@ -86,6 +85,8 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayJava() throws Exception { public void testSerializationFullArrayJava() throws Exception {
int length = 100; int length = 100;
Nd4j.create(1); Nd4j.create(1);
@ -117,6 +118,8 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsNd4jWriteRead() throws Exception { public void testSerializationOnViewsNd4jWriteRead() throws Exception {
int length = 100; int length = 100;
Nd4j.create(1); Nd4j.create(1);
@ -147,6 +150,8 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsJava() throws Exception { public void testSerializationOnViewsJava() throws Exception {
int length = 100; int length = 100;
Nd4j.create(1); Nd4j.create(1);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.rng; package org.nd4j.linalg.api.rng;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -33,14 +34,13 @@ import static org.junit.jupiter.api.Assertions.*;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class RngTests extends BaseNd4jTest { public class RngTests extends BaseNd4jTestWithBackends {
public RngTests(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testRngConstitency() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRngConstitency(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(123); Nd4j.getRandom().setSeed(123);
INDArray arr = Nd4j.rand(1, 5); INDArray arr = Nd4j.rand(1, 5);
Nd4j.getRandom().setSeed(123); Nd4j.getRandom().setSeed(123);
@ -49,7 +49,9 @@ public class RngTests extends BaseNd4jTest {
} }
@Test @Test
public void testRandomWithOrder() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomWithOrder(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -105,7 +107,9 @@ public class RngTests extends BaseNd4jTest {
} }
@Test @Test
public void testRandomBinomial() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomBinomial(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
//silly tests. Just increasing the usage for randomBinomial to stop compiler warnings. //silly tests. Just increasing the usage for randomBinomial to stop compiler warnings.
INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3); INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3);

View File

@ -23,9 +23,10 @@ package org.nd4j.linalg.api.string;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.Assert; import org.junit.Assert;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -35,22 +36,23 @@ import org.nd4j.linalg.string.NDArrayStrings;
* @author Adam Gibson * @author Adam Gibson
*/ */
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class TestFormatting extends BaseNd4jTest {
public TestFormatting(Nd4jBackend backend) { public class TestFormatting extends BaseNd4jTestWithBackends {
super(backend);
}
@Test @Test
public void testTwoByTwo() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTwoByTwo(Nd4jBackend backend) {
INDArray arr = Nd4j.create(2, 2, 2, 2); INDArray arr = Nd4j.create(2, 2, 2, 2);
System.out.println(new NDArrayStrings().format(arr)); System.out.println(new NDArrayStrings().format(arr));
} }
@Test @Test
public void testNd4jArrayString() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNd4jArrayString(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new float[]{1f, 20000000f, 40.838383f, 3f}, new int[]{2, 2}); INDArray arr = Nd4j.create(new float[]{1f, 20000000f, 40.838383f, 3f}, new int[]{2, 2});
@ -71,7 +73,9 @@ public class TestFormatting extends BaseNd4jTest {
} }
@Test @Test
public void testRange() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRange(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[][]{ INDArray arr = Nd4j.create(new double[][]{
{-1,0,1,0}, {-1,0,1,0},
{-0.1, 0.1, -10, 10}, {-0.1, 0.1, -10, 10},

Some files were not shown because too many files have changed in this diff Show More