Migrate parameterized tests to junit 5

master
agibsonccc 2021-03-16 22:08:35 +09:00
parent 82bdcc21d2
commit 3c6014271e
260 changed files with 10787 additions and 6270 deletions

View File

@ -37,8 +37,10 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -47,13 +49,14 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays; import java.util.Arrays;
import java.util.stream.Stream;
import static org.deeplearning4j.nn.conf.ConvolutionMode.Same; import static org.deeplearning4j.nn.conf.ConvolutionMode.Same;
import static org.deeplearning4j.nn.conf.ConvolutionMode.Truncate; import static org.deeplearning4j.nn.conf.ConvolutionMode.Truncate;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(Parameterized.class)
@DisplayName("Cnn Gradient Check Test") @DisplayName("Cnn Gradient Check Test")
class CNNGradientCheckTest extends BaseDL4JTest { class CNNGradientCheckTest extends BaseDL4JTest {
@ -71,15 +74,10 @@ class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
private CNN2DFormat format;
public CNNGradientCheckTest(CNN2DFormat format) {
this.format = format;
}
@Parameterized.Parameters(name = "{0}") public static Stream<Arguments> params() {
public static Object[] params() { return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of);
return CNN2DFormat.values();
} }
@Override @Override
@ -89,9 +87,11 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Gradient CNNMLN") @DisplayName("Test Gradient CNNMLN")
void testGradientCNNMLN() { @ParameterizedTest
@MethodSource("#params")
public void testGradientCNNMLN(CNN2DFormat format) {
if (// Only test NCHW due to flat input format... if (// Only test NCHW due to flat input format...
this.format != CNN2DFormat.NCHW) format != CNN2DFormat.NCHW)
return; return;
// Parameterized test, testing combinations of: // Parameterized test, testing combinations of:
// (a) activation function // (a) activation function
@ -146,9 +146,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Gradient CNNL 1 L 2 MLN") @DisplayName("Test Gradient CNNL 1 L 2 MLN")
void testGradientCNNL1L2MLN() { void testGradientCNNL1L2MLN(CNN2DFormat format) {
if (// Only test NCHW due to flat input format... if (// Only test NCHW due to flat input format...
this.format != CNN2DFormat.NCHW) format != CNN2DFormat.NCHW)
return; return;
// Parameterized test, testing combinations of: // Parameterized test, testing combinations of:
// (a) activation function // (a) activation function
@ -245,7 +245,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn With Space To Batch") @DisplayName("Test Cnn With Space To Batch")
void testCnnWithSpaceToBatch() { @ParameterizedTest
@MethodSource("#params")
public void testCnnWithSpaceToBatch(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int[] minibatchSizes = { 2, 4 }; int[] minibatchSizes = { 2, 4 };
@ -289,7 +291,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn With Upsampling") @DisplayName("Test Cnn With Upsampling")
void testCnnWithUpsampling() { @ParameterizedTest
@MethodSource("#params")
void testCnnWithUpsampling(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int[] minibatchSizes = { 1, 3 }; int[] minibatchSizes = { 1, 3 };
@ -323,7 +327,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn With Subsampling") @DisplayName("Test Cnn With Subsampling")
void testCnnWithSubsampling() { @ParameterizedTest
@MethodSource("#params")
void testCnnWithSubsampling(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int[] minibatchSizes = { 1, 3 }; int[] minibatchSizes = { 1, 3 };
@ -365,7 +371,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn With Subsampling V 2") @DisplayName("Test Cnn With Subsampling V 2")
void testCnnWithSubsamplingV2() { @ParameterizedTest
@MethodSource("#params")
void testCnnWithSubsamplingV2(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int[] minibatchSizes = { 1, 3 }; int[] minibatchSizes = { 1, 3 };
@ -403,7 +411,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Locally Connected 2 D") @DisplayName("Test Cnn Locally Connected 2 D")
void testCnnLocallyConnected2D() { @ParameterizedTest
@MethodSource("#params")
void testCnnLocallyConnected2D(CNN2DFormat format) {
int nOut = 3; int nOut = 3;
int width = 5; int width = 5;
int height = 5; int height = 5;
@ -433,7 +443,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Multi Layer") @DisplayName("Test Cnn Multi Layer")
void testCnnMultiLayer() { @ParameterizedTest
@MethodSource("#params")
void testCnnMultiLayer(CNN2DFormat format) {
int nOut = 2; int nOut = 2;
int[] minibatchSizes = { 1, 2, 5 }; int[] minibatchSizes = { 1, 2, 5 };
int width = 5; int width = 5;
@ -473,7 +485,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Same Padding Mode") @DisplayName("Test Cnn Same Padding Mode")
void testCnnSamePaddingMode() { @ParameterizedTest
@MethodSource("#params")
void testCnnSamePaddingMode(CNN2DFormat format) {
int nOut = 2; int nOut = 2;
int[] minibatchSizes = { 1, 3, 3, 2, 1, 2 }; int[] minibatchSizes = { 1, 3, 3, 2, 1, 2 };
// Same padding mode: insensitive to exact input size... // Same padding mode: insensitive to exact input size...
@ -507,7 +521,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Same Padding Mode Strided") @DisplayName("Test Cnn Same Padding Mode Strided")
void testCnnSamePaddingModeStrided() { @ParameterizedTest
@MethodSource("#params")
void testCnnSamePaddingModeStrided(CNN2DFormat format) {
int nOut = 2; int nOut = 2;
int[] minibatchSizes = { 1, 3 }; int[] minibatchSizes = { 1, 3 };
int width = 16; int width = 16;
@ -550,7 +566,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Zero Padding Layer") @DisplayName("Test Cnn Zero Padding Layer")
void testCnnZeroPaddingLayer() { @ParameterizedTest
@MethodSource("#params")
void testCnnZeroPaddingLayer(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 4; int nOut = 4;
int width = 6; int width = 6;
@ -596,7 +614,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Deconvolution 2 D") @DisplayName("Test Deconvolution 2 D")
void testDeconvolution2D() { @ParameterizedTest
@MethodSource("#params")
void testDeconvolution2D(CNN2DFormat format) {
int nOut = 2; int nOut = 2;
int[] minibatchSizes = new int[] { 1, 3, 3, 1, 3 }; int[] minibatchSizes = new int[] { 1, 3, 3, 1, 3 };
int[] kernelSizes = new int[] { 1, 1, 1, 3, 3 }; int[] kernelSizes = new int[] { 1, 1, 1, 3, 3 };
@ -641,7 +661,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Separable Conv 2 D") @DisplayName("Test Separable Conv 2 D")
void testSeparableConv2D() { @ParameterizedTest
@MethodSource("#params")
void testSeparableConv2D(CNN2DFormat format) {
int nOut = 2; int nOut = 2;
int width = 6; int width = 6;
int height = 6; int height = 6;
@ -686,7 +708,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cnn Dilated") @DisplayName("Test Cnn Dilated")
void testCnnDilated() { @ParameterizedTest
@MethodSource("#params")
void testCnnDilated(CNN2DFormat format) {
int nOut = 2; int nOut = 2;
int minibatchSize = 2; int minibatchSize = 2;
int width = 8; int width = 8;
@ -736,7 +760,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Cropping 2 D Layer") @DisplayName("Test Cropping 2 D Layer")
void testCropping2DLayer() { @ParameterizedTest
@MethodSource("#params")
void testCropping2DLayer(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nOut = 2; int nOut = 2;
int width = 12; int width = 12;
@ -780,7 +806,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Depthwise Conv 2 D") @DisplayName("Test Depthwise Conv 2 D")
void testDepthwiseConv2D() { @ParameterizedTest
@MethodSource("#params")
void testDepthwiseConv2D(CNN2DFormat format) {
int nIn = 3; int nIn = 3;
int depthMultiplier = 2; int depthMultiplier = 2;
int nOut = nIn * depthMultiplier; int nOut = nIn * depthMultiplier;

View File

@ -39,8 +39,10 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -55,26 +57,22 @@ import java.io.File;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.io.InputStream; import java.io.InputStream;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class YoloGradientCheckTests extends BaseDL4JTest { public class YoloGradientCheckTests extends BaseDL4JTest {
static { static {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
private CNN2DFormat format;
public YoloGradientCheckTests(CNN2DFormat format){
this.format = format;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return CNN2DFormat.values();
}
public static Stream<Arguments> params() {
return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of);
}
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
@ -82,7 +80,9 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
} }
@Test @Test
public void testYoloOutputLayer() { @ParameterizedTest
@MethodSource("#params")
public void testYoloOutputLayer(CNN2DFormat format) {
int depthIn = 2; int depthIn = 2;
int c = 3; int c = 3;
int b = 3; int b = 3;
@ -159,13 +159,13 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
} }
} }
private static INDArray yoloLabels(int mb, int c, int h, int w){ private static INDArray yoloLabels(int mb, int c, int h, int w) {
int labelDepth = 4 + c; int labelDepth = 4 + c;
INDArray labels = Nd4j.zeros(mb, labelDepth, h, w); INDArray labels = Nd4j.zeros(mb, labelDepth, h, w);
//put 1 object per minibatch, at positions (0,0), (1,1) etc. //put 1 object per minibatch, at positions (0,0), (1,1) etc.
//Positions for label boxes: (1,1) to (2,2), (2,2) to (4,4) etc //Positions for label boxes: (1,1) to (2,2), (2,2) to (4,4) etc
for( int i=0; i<mb; i++ ){ for( int i = 0; i < mb; i++) {
//Class labels //Class labels
labels.putScalar(i, 4 + i%c, i%h, i%w, 1); labels.putScalar(i, 4 + i%c, i%h, i%w, 1);
@ -181,7 +181,9 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
@Test @Test
public void yoloGradientCheckRealData(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("#params")
public void yoloGradientCheckRealData(@TempDir Path testDir,CNN2DFormat format) throws Exception {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
InputStream is1 = new ClassPathResource("yolo/VOC_TwoImage/JPEGImages/2007_009346.jpg").getInputStream(); InputStream is1 = new ClassPathResource("yolo/VOC_TwoImage/JPEGImages/2007_009346.jpg").getInputStream();
InputStream is2 = new ClassPathResource("yolo/VOC_TwoImage/Annotations/2007_009346.xml").getInputStream(); InputStream is2 = new ClassPathResource("yolo/VOC_TwoImage/Annotations/2007_009346.xml").getInputStream();

View File

@ -39,8 +39,10 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ConvolutionUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -48,24 +50,19 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class ConvDataFormatTests extends BaseDL4JTest { public class ConvDataFormatTests extends BaseDL4JTest {
private final DataType dataType; public static Stream<Arguments> params(){
return Arrays.asList(new DataType[]{DataType.FLOAT, DataType.DOUBLE}).stream().map(Arguments::of);
public ConvDataFormatTests(DataType dataType){
this.dataType = dataType;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return new DataType[]{DataType.FLOAT, DataType.DOUBLE};
} }
@Override @Override
@ -74,7 +71,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testConv2d() { @MethodSource("#params")
@ParameterizedTest
public void testConv2d(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -83,15 +82,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getConv2dNet(CNN2DFormat.NCHW, true, cm)) .net1(getConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getConv2dNet(CNN2DFormat.NCHW, false, cm)) .net2(getConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getConv2dNet(CNN2DFormat.NHWC, true, cm)) .net3(getConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getConv2dNet(CNN2DFormat.NHWC, false, cm)) .net4(getConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -107,7 +106,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testSubsampling2d() { @MethodSource("#params")
@ParameterizedTest
public void testSubsampling2d(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -116,15 +117,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getSubsampling2dNet(CNN2DFormat.NCHW, true, cm)) .net1(getSubsampling2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getSubsampling2dNet(CNN2DFormat.NCHW, false, cm)) .net2(getSubsampling2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getSubsampling2dNet(CNN2DFormat.NHWC, true, cm)) .net3(getSubsampling2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getSubsampling2dNet(CNN2DFormat.NHWC, false, cm)) .net4(getSubsampling2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -140,7 +141,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testDepthwiseConv2d() { @MethodSource("#params")
@ParameterizedTest
public void testDepthwiseConv2d(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -149,15 +152,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getDepthwiseConv2dNet(CNN2DFormat.NCHW, true, cm)) .net1(getDepthwiseConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getDepthwiseConv2dNet(CNN2DFormat.NCHW, false, cm)) .net2(getDepthwiseConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getDepthwiseConv2dNet(CNN2DFormat.NHWC, true, cm)) .net3(getDepthwiseConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getDepthwiseConv2dNet(CNN2DFormat.NHWC, false, cm)) .net4(getDepthwiseConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -173,7 +176,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testSeparableConv2d() { @MethodSource("#params")
@ParameterizedTest
public void testSeparableConv2d(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -182,15 +187,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getSeparableConv2dNet(CNN2DFormat.NCHW, true, cm)) .net1(getSeparableConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getSeparableConv2dNet(CNN2DFormat.NCHW, false, cm)) .net2(getSeparableConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getSeparableConv2dNet(CNN2DFormat.NHWC, true, cm)) .net3(getSeparableConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getSeparableConv2dNet(CNN2DFormat.NHWC, false, cm)) .net4(getSeparableConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -206,7 +211,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testDeconv2d() { @MethodSource("#params")
@ParameterizedTest
public void testDeconv2d(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -215,15 +222,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getDeconv2DNet2dNet(CNN2DFormat.NCHW, true, cm)) .net1(getDeconv2DNet2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getDeconv2DNet2dNet(CNN2DFormat.NCHW, false, cm)) .net2(getDeconv2DNet2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getDeconv2DNet2dNet(CNN2DFormat.NHWC, true, cm)) .net3(getDeconv2DNet2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getDeconv2DNet2dNet(CNN2DFormat.NHWC, false, cm)) .net4(getDeconv2DNet2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -239,7 +246,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testLRN() { @MethodSource("#params")
@ParameterizedTest
public void testLRN(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -248,15 +257,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getLrnLayer(CNN2DFormat.NCHW, true, cm)) .net1(getLrnLayer(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getLrnLayer(CNN2DFormat.NCHW, false, cm)) .net2(getLrnLayer(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getLrnLayer(CNN2DFormat.NHWC, true, cm)) .net3(getLrnLayer(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getLrnLayer(CNN2DFormat.NHWC, false, cm)) .net4(getLrnLayer(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -272,7 +281,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testZeroPaddingLayer(){ @MethodSource("#params")
@ParameterizedTest
public void testZeroPaddingLayer(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -280,15 +291,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers"; String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getZeroPaddingNet(CNN2DFormat.NCHW, true)) .net1(getZeroPaddingNet(dataType,CNN2DFormat.NCHW, true))
.net2(getZeroPaddingNet(CNN2DFormat.NCHW, false)) .net2(getZeroPaddingNet(dataType,CNN2DFormat.NCHW, false))
.net3(getZeroPaddingNet(CNN2DFormat.NHWC, true)) .net3(getZeroPaddingNet(dataType,CNN2DFormat.NHWC, true))
.net4(getZeroPaddingNet(CNN2DFormat.NHWC, false)) .net4(getZeroPaddingNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -303,7 +314,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testCropping2DLayer(){ @MethodSource("#params")
@ParameterizedTest
public void testCropping2DLayer(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -311,15 +324,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers"; String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getCropping2dNet(CNN2DFormat.NCHW, true)) .net1(getCropping2dNet(dataType,CNN2DFormat.NCHW, true))
.net2(getCropping2dNet(CNN2DFormat.NCHW, false)) .net2(getCropping2dNet(dataType,CNN2DFormat.NCHW, false))
.net3(getCropping2dNet(CNN2DFormat.NHWC, true)) .net3(getCropping2dNet(dataType,CNN2DFormat.NHWC, true))
.net4(getCropping2dNet(CNN2DFormat.NHWC, false)) .net4(getCropping2dNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -334,7 +347,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testUpsampling2d(){ @MethodSource("#params")
@ParameterizedTest
public void testUpsampling2d(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -342,15 +357,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers"; String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getUpsamplingNet(CNN2DFormat.NCHW, true)) .net1(getUpsamplingNet(dataType,CNN2DFormat.NCHW, true))
.net2(getUpsamplingNet(CNN2DFormat.NCHW, false)) .net2(getUpsamplingNet(dataType,CNN2DFormat.NCHW, false))
.net3(getUpsamplingNet(CNN2DFormat.NHWC, true)) .net3(getUpsamplingNet(dataType,CNN2DFormat.NHWC, true))
.net4(getUpsamplingNet(CNN2DFormat.NHWC, false)) .net4(getUpsamplingNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -365,7 +380,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testBatchNormNet(){ @MethodSource("#params")
@ParameterizedTest
public void testBatchNormNet(DataType dataType) {
try { try {
for(boolean useLogStd : new boolean[]{true, false}) { for(boolean useLogStd : new boolean[]{true, false}) {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
@ -374,15 +391,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std"); String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std");
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, true)) .net1(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NCHW, true))
.net2(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, false)) .net2(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NCHW, false))
.net3(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, true)) .net3(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NHWC, true))
.net4(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, false)) .net4(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NHWC, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -398,7 +415,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testCnnLossLayer() { @MethodSource("#params")
@ParameterizedTest
public void testCnnLossLayer(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -406,8 +425,8 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers"; String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labelsNHWC = TestUtils.randomOneHot(this.dataType,2*6*6, 3); INDArray labelsNHWC = TestUtils.randomOneHot(dataType,2*6*6, 3);
labelsNHWC = labelsNHWC.reshape(2,6,6,3); labelsNHWC = labelsNHWC.reshape(2,6,6,3);
INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup(); INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup();
@ -434,7 +453,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testSpaceToDepthNet(){ @MethodSource("#params")
@ParameterizedTest
public void testSpaceToDepthNet(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -442,15 +463,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers"; String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getSpaceToDepthNet(CNN2DFormat.NCHW, true)) .net1(getSpaceToDepthNet(dataType,CNN2DFormat.NCHW, true))
.net2(getSpaceToDepthNet(CNN2DFormat.NCHW, false)) .net2(getSpaceToDepthNet(dataType,CNN2DFormat.NCHW, false))
.net3(getSpaceToDepthNet(CNN2DFormat.NHWC, true)) .net3(getSpaceToDepthNet(dataType,CNN2DFormat.NHWC, true))
.net4(getSpaceToDepthNet(CNN2DFormat.NHWC, false)) .net4(getSpaceToDepthNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -465,7 +486,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testSpaceToBatchNet(){ @MethodSource("#params")
@ParameterizedTest
public void testSpaceToBatchNet(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -473,15 +496,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers"; String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 16, 16); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 16, 16);
INDArray labels = TestUtils.randomOneHot(8, 10); INDArray labels = TestUtils.randomOneHot(8, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getSpaceToBatchNet(CNN2DFormat.NCHW, true)) .net1(getSpaceToBatchNet(dataType,CNN2DFormat.NCHW, true))
.net2(getSpaceToBatchNet(CNN2DFormat.NCHW, false)) .net2(getSpaceToBatchNet(dataType,CNN2DFormat.NCHW, false))
.net3(getSpaceToBatchNet(CNN2DFormat.NHWC, true)) .net3(getSpaceToBatchNet(dataType,CNN2DFormat.NHWC, true))
.net4(getSpaceToBatchNet(CNN2DFormat.NHWC, false)) .net4(getSpaceToBatchNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -496,7 +519,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testLocallyConnected() { @MethodSource("#params")
@ParameterizedTest
public void testLocallyConnected(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -505,15 +530,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getLocallyConnectedNet(CNN2DFormat.NCHW, true, cm)) .net1(getLocallyConnectedNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getLocallyConnectedNet(CNN2DFormat.NCHW, false, cm)) .net2(getLocallyConnectedNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getLocallyConnectedNet(CNN2DFormat.NHWC, true, cm)) .net3(getLocallyConnectedNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getLocallyConnectedNet(CNN2DFormat.NHWC, false, cm)) .net4(getLocallyConnectedNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -530,7 +555,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
@Test @Test
public void testGlobalPooling() { @MethodSource("#params")
@ParameterizedTest
public void testGlobalPooling(DataType dataType) {
try { try {
for (boolean helpers : new boolean[]{false, true}) { for (boolean helpers : new boolean[]{false, true}) {
for (PoolingType pt : PoolingType.values()) { for (PoolingType pt : PoolingType.values()) {
@ -539,15 +566,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")"; String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")";
System.out.println(" --- " + msg + " ---"); System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10); INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder() TestCase tc = TestCase.builder()
.msg(msg) .msg(msg)
.net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true)) .net1(getGlobalPoolingNet(dataType,CNN2DFormat.NCHW, pt, true))
.net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false)) .net2(getGlobalPoolingNet(dataType,CNN2DFormat.NCHW, pt, false))
.net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true)) .net3(getGlobalPoolingNet(dataType,CNN2DFormat.NHWC, pt, true))
.net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false)) .net4(getGlobalPoolingNet(dataType,CNN2DFormat.NHWC, pt, false))
.inNCHW(inNCHW) .inNCHW(inNCHW)
.labelsNCHW(labels) .labelsNCHW(labels)
.labelsNHWC(labels) .labelsNHWC(labels)
@ -562,9 +589,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new ConvolutionLayer.Builder() return getNetWithLayer(dataType,new ConvolutionLayer.Builder()
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
@ -573,7 +600,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.helperAllowFallback(false) .helperAllowFallback(false)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new ConvolutionLayer.Builder() return getNetWithLayer(dataType,new ConvolutionLayer.Builder()
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
@ -583,16 +610,16 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getSubsampling2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getSubsampling2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new SubsamplingLayer.Builder() return getNetWithLayer(dataType,new SubsamplingLayer.Builder()
.kernelSize(2, 2) .kernelSize(2, 2)
.stride(1, 1) .stride(1, 1)
.dataFormat(format) .dataFormat(format)
.helperAllowFallback(false) .helperAllowFallback(false)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new SubsamplingLayer.Builder() return getNetWithLayer(dataType,new SubsamplingLayer.Builder()
.kernelSize(2, 2) .kernelSize(2, 2)
.stride(1, 1) .stride(1, 1)
.helperAllowFallback(false) .helperAllowFallback(false)
@ -600,9 +627,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getSeparableConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getSeparableConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new SeparableConvolution2D.Builder() return getNetWithLayer(dataType,new SeparableConvolution2D.Builder()
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
@ -611,7 +638,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.helperAllowFallback(false) .helperAllowFallback(false)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new SeparableConvolution2D.Builder() return getNetWithLayer(dataType,new SeparableConvolution2D.Builder()
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
@ -621,9 +648,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getDepthwiseConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getDepthwiseConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new DepthwiseConvolution2D.Builder() return getNetWithLayer(dataType,new DepthwiseConvolution2D.Builder()
.depthMultiplier(2) .depthMultiplier(2)
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
@ -633,7 +660,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.helperAllowFallback(false) .helperAllowFallback(false)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new DepthwiseConvolution2D.Builder() return getNetWithLayer(dataType,new DepthwiseConvolution2D.Builder()
.depthMultiplier(2) .depthMultiplier(2)
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
@ -644,59 +671,59 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getLrnLayer(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getLrnLayer(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new LocalResponseNormalization.Builder() return getNetWithLayer(dataType,new LocalResponseNormalization.Builder()
.dataFormat(format) .dataFormat(format)
.helperAllowFallback(false) .helperAllowFallback(false)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new LocalResponseNormalization.Builder() return getNetWithLayer(dataType,new LocalResponseNormalization.Builder()
.helperAllowFallback(false) .helperAllowFallback(false)
.build(), format, cm, null); .build(), format, cm, null);
} }
} }
private MultiLayerNetwork getZeroPaddingNet(CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getZeroPaddingNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2) return getNetWithLayer(dataType,new ZeroPaddingLayer.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null); .dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else { } else {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2).build(), return getNetWithLayer(dataType,new ZeroPaddingLayer.Builder(2,2).build(),
format, ConvolutionMode.Same, null); format, ConvolutionMode.Same, null);
} }
} }
private MultiLayerNetwork getCropping2dNet(CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getCropping2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new Cropping2D.Builder(2,2) return getNetWithLayer(dataType,new Cropping2D.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null); .dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else { } else {
return getNetWithLayer(new Cropping2D.Builder(2,2) return getNetWithLayer(dataType,new Cropping2D.Builder(2,2)
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} }
} }
private MultiLayerNetwork getUpsamplingNet(CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getUpsamplingNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new Upsampling2D.Builder(2) return getNetWithLayer(dataType,new Upsampling2D.Builder(2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null); .dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else { } else {
return getNetWithLayer(new Upsampling2D.Builder(2) return getNetWithLayer(dataType,new Upsampling2D.Builder(2)
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} }
} }
private MultiLayerNetwork getDeconv2DNet2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getDeconv2DNet2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2) return getNetWithLayer(dataType,new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH) .activation(Activation.TANH)
.kernelSize(2,2) .kernelSize(2,2)
.dataFormat(format) .dataFormat(format)
.stride(2,2) .stride(2,2)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2) return getNetWithLayer(dataType,new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH) .activation(Activation.TANH)
.kernelSize(2,2) .kernelSize(2,2)
.dataFormat(format) .dataFormat(format)
@ -705,50 +732,50 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getBatchNormNet(boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getBatchNormNet(DataType dataType,boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new BatchNormalization.Builder() return getNetWithLayer(dataType,new BatchNormalization.Builder()
.useLogStd(logStdev) .useLogStd(logStdev)
.dataFormat(format) .dataFormat(format)
.helperAllowFallback(false) .helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null); .nOut(3).build(), format, ConvolutionMode.Same, null);
} else { } else {
return getNetWithLayer(new BatchNormalization.Builder() return getNetWithLayer(dataType,new BatchNormalization.Builder()
.useLogStd(logStdev) .useLogStd(logStdev)
.helperAllowFallback(false) .helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null); .nOut(3).build(), format, ConvolutionMode.Same, null);
} }
} }
private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getSpaceToDepthNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToDepthLayer.Builder() return getNetWithLayer(dataType,new SpaceToDepthLayer.Builder()
.blocks(2) .blocks(2)
.dataFormat(format) .dataFormat(format)
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} else { } else {
return getNetWithLayer(new SpaceToDepthLayer.Builder() return getNetWithLayer(dataType,new SpaceToDepthLayer.Builder()
.blocks(2) .blocks(2)
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} }
} }
private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getSpaceToBatchNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToBatchLayer.Builder() return getNetWithLayer(dataType,new SpaceToBatchLayer.Builder()
.blocks(2, 2) .blocks(2, 2)
.dataFormat(format) .dataFormat(format)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
} else { } else {
return getNetWithLayer(new SpaceToBatchLayer.Builder() return getNetWithLayer(dataType,new SpaceToBatchLayer.Builder()
.blocks(2, 2) .blocks(2, 2)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
} }
} }
private MultiLayerNetwork getLocallyConnectedNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { private MultiLayerNetwork getLocallyConnectedNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new LocallyConnected2D.Builder() return getNetWithLayer(dataType,new LocallyConnected2D.Builder()
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
@ -756,7 +783,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.nOut(3) .nOut(3)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(new LocallyConnected2D.Builder() return getNetWithLayer(dataType,new LocallyConnected2D.Builder()
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
@ -765,9 +792,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
} }
} }
private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) { private MultiLayerNetwork getNetWithLayer(DataType dataType,Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) {
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.dataType(this.dataType) .dataType(dataType)
.seed(12345) .seed(12345)
.convolutionMode(cm) .convolutionMode(cm)
.list() .list()
@ -794,13 +821,13 @@ public class ConvDataFormatTests extends BaseDL4JTest {
return net; return net;
} }
private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) { private MultiLayerNetwork getGlobalPoolingNet(DataType dataType,CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt) return getNetWithLayer(dataType,new GlobalPoolingLayer.Builder(pt)
.poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2}) .poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2})
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} else { } else {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt) return getNetWithLayer(dataType,new GlobalPoolingLayer.Builder(pt)
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} }
} }

View File

@ -45,8 +45,11 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.TimeSeriesUtils; import org.deeplearning4j.util.TimeSeriesUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -61,30 +64,29 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.deeplearning4j.nn.conf.RNNFormat.NCW; import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
@DisplayName("Bidirectional Test") @DisplayName("Bidirectional Test")
class BidirectionalTest extends BaseDL4JTest { class BidirectionalTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public BidirectionalTest(RNNFormat rnnDataFormat) {
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters public static Stream<Arguments> params() {
public static Object[] params() { return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
return RNNFormat.values();
} }
@Test @Test
@DisplayName("Compare Implementations") @DisplayName("Compare Implementations")
void compareImplementations() { @ParameterizedTest
@MethodSource("#params")
void compareImplementations(RNNFormat rnnDataFormat) {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
// Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params // Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params
@ -147,9 +149,11 @@ class BidirectionalTest extends BaseDL4JTest {
} }
} }
@Test
@DisplayName("Compare Implementations Comp Graph") @DisplayName("Compare Implementations Comp Graph")
void compareImplementationsCompGraph() { @Test
@ParameterizedTest
@MethodSource("#params")
void compareImplementationsCompGraph(RNNFormat rnnFormat) {
// for(WorkspaceMode wsm : WorkspaceMode.values()) { // for(WorkspaceMode wsm : WorkspaceMode.values()) {
for (WorkspaceMode wsm : new WorkspaceMode[] { WorkspaceMode.NONE, WorkspaceMode.ENABLED }) { for (WorkspaceMode wsm : new WorkspaceMode[] { WorkspaceMode.NONE, WorkspaceMode.ENABLED }) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
@ -187,8 +191,8 @@ class BidirectionalTest extends BaseDL4JTest {
Gradient g2 = net2.gradient(); Gradient g2 = net2.gradient();
assertEquals(g1.gradient(), g2.gradient()); assertEquals(g1.gradient(), g2.gradient());
// Ensure updates are equal: // Ensure updates are equal:
ComputationGraphUpdater u1 = (ComputationGraphUpdater) net1.getUpdater(); ComputationGraphUpdater u1 = net1.getUpdater();
ComputationGraphUpdater u2 = (ComputationGraphUpdater) net2.getUpdater(); ComputationGraphUpdater u2 = net2.getUpdater();
assertEquals(u1.getUpdaterStateViewArray(), u2.getUpdaterStateViewArray()); assertEquals(u1.getUpdaterStateViewArray(), u2.getUpdaterStateViewArray());
u1.update(g1, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); u1.update(g1, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces());
u2.update(g2, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); u2.update(g2, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces());
@ -205,7 +209,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Serialization") @DisplayName("Test Serialization")
void testSerialization() throws Exception { @ParameterizedTest
@MethodSource("#params")
void testSerialization(RNNFormat rnnDataFormat) throws Exception {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -242,7 +248,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Serialization Comp Graph") @DisplayName("Test Serialization Comp Graph")
void testSerializationCompGraph() throws Exception { @ParameterizedTest
@MethodSource("#params")
void testSerializationCompGraph(RNNFormat rnnDataFormat) throws Exception {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -277,7 +285,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Simple Bidirectional") @DisplayName("Test Simple Bidirectional")
void testSimpleBidirectional() { @ParameterizedTest
@MethodSource("#params")
public void testSimpleBidirectional(RNNFormat rnnDataFormat) {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -362,7 +372,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Simple Bidirectional Comp Graph") @DisplayName("Test Simple Bidirectional Comp Graph")
void testSimpleBidirectionalCompGraph() { @ParameterizedTest
@MethodSource("#params")
void testSimpleBidirectionalCompGraph(RNNFormat rnnDataFormat) {
for (WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm); log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -19,7 +19,6 @@
*/ */
package org.deeplearning4j.nn.layers.recurrent; package org.deeplearning4j.nn.layers.recurrent;
import junit.framework.TestCase;
import lombok.val; import lombok.val;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@ -34,9 +33,12 @@ import org.deeplearning4j.nn.params.GravesBidirectionalLSTMParamInitializer;
import org.deeplearning4j.nn.params.GravesLSTMParamInitializer; import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -44,31 +46,29 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.common.primitives.Pair;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(Parameterized.class) import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Graves Bidirectional LSTM Test") @DisplayName("Graves Bidirectional LSTM Test")
class GravesBidirectionalLSTMTest extends BaseDL4JTest { class GravesBidirectionalLSTMTest extends BaseDL4JTest {
private double score = 0.0; private double score = 0.0;
private RNNFormat rnnDataFormat;
public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat) {
this.rnnDataFormat = rnnDataFormat; public static Stream<Arguments> params(){
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
} }
@Parameterized.Parameters
public static Object[] params() {
return RNNFormat.values();
}
@Test @Test
@DisplayName("Test Bidirectional LSTM Graves Forward Basic") @DisplayName("Test Bidirectional LSTM Graves Forward Basic")
void testBidirectionalLSTMGravesForwardBasic() { @MethodSource("#params")
@ParameterizedTest
void testBidirectionalLSTMGravesForwardBasic(RNNFormat rnnDataFormat) {
// Very basic test of forward prop. of LSTM layer with a time series. // Very basic test of forward prop. of LSTM layer with a time series.
// Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. // Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape.
int nIn = 13; int nIn = 13;
@ -110,19 +110,21 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Bidirectional LSTM Graves Backward Basic") @DisplayName("Test Bidirectional LSTM Graves Backward Basic")
void testBidirectionalLSTMGravesBackwardBasic() { @MethodSource("#params")
@ParameterizedTest
void testBidirectionalLSTMGravesBackwardBasic(RNNFormat rnnDataFormat) {
// Very basic test of backprop for mini-batch + time series // Very basic test of backprop for mini-batch + time series
// Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. // Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape.
testGravesBackwardBasicHelper(13, 3, 17, 10, 7); testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 7);
// Edge case: miniBatchSize = 1 // Edge case: miniBatchSize = 1
testGravesBackwardBasicHelper(13, 3, 17, 1, 7); testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 1, 7);
// Edge case: timeSeriesLength = 1 // Edge case: timeSeriesLength = 1
testGravesBackwardBasicHelper(13, 3, 17, 10, 1); testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 1);
// Edge case: both miniBatchSize = 1 and timeSeriesLength = 1 // Edge case: both miniBatchSize = 1 and timeSeriesLength = 1
testGravesBackwardBasicHelper(13, 3, 17, 1, 1); testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 1, 1);
} }
private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) { private void testGravesBackwardBasicHelper(RNNFormat rnnDataFormat,int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) {
INDArray inputData = (rnnDataFormat == RNNFormat.NCW) ? Nd4j.ones(miniBatchSize, nIn, timeSeriesLength) : Nd4j.ones(miniBatchSize, timeSeriesLength, nIn); INDArray inputData = (rnnDataFormat == RNNFormat.NCW) ? Nd4j.ones(miniBatchSize, nIn, timeSeriesLength) : Nd4j.ones(miniBatchSize, timeSeriesLength, nIn);
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat).dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()).build(); NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat).dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()).build();
long numParams = conf.getLayer().initializer().numParams(conf); long numParams = conf.getLayer().initializer().numParams(conf);
@ -204,7 +206,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Get Set Parmas") @DisplayName("Test Get Set Parmas")
void testGetSetParmas() { @MethodSource("#params")
@ParameterizedTest
void testGetSetParmas(RNNFormat rnnDataFormat) {
final int nIn = 2; final int nIn = 2;
final int layerSize = 3; final int layerSize = 3;
final int miniBatchSize = 2; final int miniBatchSize = 2;
@ -224,7 +228,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Simple Forwards And Backwards Activation") @DisplayName("Test Simple Forwards And Backwards Activation")
void testSimpleForwardsAndBackwardsActivation() { @MethodSource("#params")
@ParameterizedTest
void testSimpleForwardsAndBackwardsActivation(RNNFormat rnnDataFormat) {
final int nIn = 2; final int nIn = 2;
final int layerSize = 3; final int layerSize = 3;
final int miniBatchSize = 1; final int miniBatchSize = 1;
@ -342,7 +348,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test @Test
@DisplayName("Test Gate Activation Fns Sanity Check") @DisplayName("Test Gate Activation Fns Sanity Check")
void testGateActivationFnsSanityCheck() { @MethodSource("#params")
@ParameterizedTest
void testGateActivationFnsSanityCheck(RNNFormat rnnDataFormat) {
for (String gateAfn : new String[] { "sigmoid", "hardsigmoid" }) { for (String gateAfn : new String[] { "sigmoid", "hardsigmoid" }) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat).activation(Activation.TANH).build()).build(); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat).activation(Activation.TANH).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);

View File

@ -30,36 +30,35 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(Parameterized.class)
@DisplayName("Mask Zero Layer Test") @DisplayName("Mask Zero Layer Test")
class MaskZeroLayerTest extends BaseDL4JTest { class MaskZeroLayerTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public MaskZeroLayerTest(RNNFormat rnnDataFormat) { public static Stream<Arguments> params() {
this.rnnDataFormat = rnnDataFormat; return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
} }
@Parameterized.Parameters
public static Object[] params() {
return RNNFormat.values();
}
@Test
@DisplayName("Activate") @DisplayName("Activate")
void activate() { @Test
@ParameterizedTest
@MethodSource("#params")
void activate(RNNFormat rnnDataFormat) {
// GIVEN two examples where some of the timesteps are zero. // GIVEN two examples where some of the timesteps are zero.
INDArray ex1 = Nd4j.create(new double[][] { new double[] { 0, 3, 5 }, new double[] { 0, 0, 2 } }); INDArray ex1 = Nd4j.create(new double[][] { new double[] { 0, 3, 5 }, new double[] { 0, 0, 2 } });
INDArray ex2 = Nd4j.create(new double[][] { new double[] { 0, 0, 2 }, new double[] { 0, 0, 2 } }); INDArray ex2 = Nd4j.create(new double[][] { new double[] { 0, 0, 2 }, new double[] { 0, 0, 2 } });
@ -95,9 +94,12 @@ class MaskZeroLayerTest extends BaseDL4JTest {
assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6); assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6);
} }
@Test
@DisplayName("Test Serialization") @DisplayName("Test Serialization")
void testSerialization() { @Test
@ParameterizedTest
@MethodSource("#params")
void testSerialization(RNNFormat rnnDataFormat) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder().setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()).build(); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder().setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();

View File

@ -40,8 +40,10 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -51,30 +53,31 @@ import org.nd4j.common.primitives.Pair;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
@AllArgsConstructor @AllArgsConstructor
public class RnnDataFormatTests extends BaseDL4JTest { public class RnnDataFormatTests extends BaseDL4JTest {
private boolean helpers;
private boolean lastTimeStep;
private boolean maskZeros;
@Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}") public static Stream<Arguments> params() {
public static List params(){
List<Object[]> ret = new ArrayList<>(); List<Object[]> ret = new ArrayList<>();
for (boolean helpers: new boolean[]{true, false}) for (boolean helpers: new boolean[]{true, false})
for (boolean lastTimeStep: new boolean[]{true, false}) for (boolean lastTimeStep: new boolean[]{true, false})
for (boolean maskZero: new boolean[]{true, false}) for (boolean maskZero: new boolean[]{true, false})
ret.add(new Object[]{helpers, lastTimeStep, maskZero}); ret.add(new Object[]{helpers, lastTimeStep, maskZero});
return ret; return ret.stream().map(Arguments::of);
} }
@Test @Test
public void testSimpleRnn() { @MethodSource("#params")
@ParameterizedTest
public void testSimpleRnn(boolean helpers,
boolean lastTimeStep,
boolean maskZeros
) {
try { try {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -107,7 +110,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
} }
@Test @Test
public void testLSTM() { @ParameterizedTest
@MethodSource("#params")
public void testLSTM(boolean helpers,
boolean lastTimeStep,
boolean maskZeros) {
try { try {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -141,7 +148,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
@Test @Test
public void testGraveLSTM() { @MethodSource("#params")
@ParameterizedTest
public void testGraveLSTM(boolean helpers,
boolean lastTimeStep,
boolean maskZeros) {
try { try {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -175,7 +186,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
@Test @Test
public void testGraveBiLSTM() { @MethodSource("#params")
@ParameterizedTest
public void testGraveBiLSTM(boolean helpers,
boolean lastTimeStep,
boolean maskZeros) {
try { try {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -34,14 +34,20 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.AdaGrad;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.deeplearning4j.nn.api.OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT; import static org.deeplearning4j.nn.api.OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
import static org.deeplearning4j.nn.weights.WeightInit.XAVIER_UNIFORM; import static org.deeplearning4j.nn.weights.WeightInit.XAVIER_UNIFORM;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@ -50,20 +56,16 @@ import static org.nd4j.linalg.activations.Activation.TANH;
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE; import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;
@RunWith(Parameterized.class)
public class TestLastTimeStepLayer extends BaseDL4JTest { public class TestLastTimeStepLayer extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestLastTimeStepLayer(RNNFormat rnnDataFormat){ public static Stream<Arguments> params(){
this.rnnDataFormat = rnnDataFormat; return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Parameterized.Parameters(name="{0}")
public static Object[] params(){
return RNNFormat.values();
} }
@Test @Test
public void testLastTimeStepVertex() { @ParameterizedTest
@MethodSource("#params")
public void testLastTimeStepVertex(RNNFormat rnnDataFormat) {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
.addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder() .addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder()
@ -126,7 +128,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
} }
@Test @Test
public void testMaskingAndAllMasked(){ @ParameterizedTest
@MethodSource("#params")
public void testMaskingAndAllMasked(RNNFormat rnnDataFormat) {
ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder() ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
.optimizationAlgo(STOCHASTIC_GRADIENT_DESCENT) .optimizationAlgo(STOCHASTIC_GRADIENT_DESCENT)
.weightInit(XAVIER_UNIFORM) .weightInit(XAVIER_UNIFORM)

View File

@ -36,8 +36,11 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -49,25 +52,23 @@ import org.nd4j.common.primitives.Pair;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class TestRnnLayers extends BaseDL4JTest { public class TestRnnLayers extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestRnnLayers(RNNFormat rnnDataFormat){ public static Stream<Arguments> params(){
this.rnnDataFormat = rnnDataFormat; return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
} }
@Test @Test
public void testTimeStepIs3Dimensional() { @ParameterizedTest
@MethodSource("#params")
public void testTimeStepIs3Dimensional(RNNFormat rnnDataFormat) {
int nIn = 12; int nIn = 12;
int nOut = 3; int nOut = 3;
@ -117,7 +118,9 @@ public class TestRnnLayers extends BaseDL4JTest {
} }
@Test @Test
public void testDropoutRecurrentLayers(){ @ParameterizedTest
@MethodSource("#params")
public void testDropoutRecurrentLayers(RNNFormat rnnDataFormat){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
String[] layerTypes = new String[]{"graves", "lstm", "simple"}; String[] layerTypes = new String[]{"graves", "lstm", "simple"};
@ -215,9 +218,11 @@ public class TestRnnLayers extends BaseDL4JTest {
} }
@Test @Test
public void testMismatchedInputLabelLength(){ @ParameterizedTest
@MethodSource("#params")
public void testMismatchedInputLabelLength(RNNFormat rnnDataFormat){
for( int i=0; i<2; i++ ){ for( int i = 0; i < 2; i++) {
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder() NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()

View File

@ -29,8 +29,10 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -38,25 +40,25 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point; import static org.nd4j.linalg.indexing.NDArrayIndex.point;
@RunWith(Parameterized.class)
public class TestSimpleRnn extends BaseDL4JTest { public class TestSimpleRnn extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestSimpleRnn(RNNFormat rnnDataFormat){ public static Stream<Arguments> params() {
this.rnnDataFormat = rnnDataFormat; return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
} }
@Test @Test
public void testSimpleRnn(){ @ParameterizedTest
@MethodSource("#params")
public void testSimpleRnn(RNNFormat rnnDataFormat) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int m = 3; int m = 3;
@ -125,7 +127,9 @@ public class TestSimpleRnn extends BaseDL4JTest {
} }
@Test @Test
public void testBiasInit(){ @ParameterizedTest
@MethodSource("#params")
public void testBiasInit(RNNFormat rnnDataFormat) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 5; int nIn = 5;
int layerSize = 6; int layerSize = 6;

View File

@ -37,8 +37,10 @@ import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -47,22 +49,22 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class TestTimeDistributed extends BaseDL4JTest { public class TestTimeDistributed extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestTimeDistributed(RNNFormat rnnDataFormat){ public static Stream<Arguments> params(){
this.rnnDataFormat = rnnDataFormat; return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
} }
@Test @Test
public void testTimeDistributed(){ @ParameterizedTest
@MethodSource("#params")
public void testTimeDistributed(RNNFormat rnnDataFormat){
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) { for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
@ -133,10 +135,12 @@ public class TestTimeDistributed extends BaseDL4JTest {
@Test @Test
public void testTimeDistributedDense(){ @MethodSource("#params")
@ParameterizedTest
public void testTimeDistributedDense(RNNFormat rnnDataFormat){
for( int rnnType=0; rnnType<3; rnnType++ ) { for( int rnnType = 0; rnnType < 3; rnnType++ ) {
for( int ffType=0; ffType<3; ffType++ ) { for( int ffType = 0; ffType < 3; ffType++ ) {
Layer l0, l2; Layer l0, l2;
switch (rnnType) { switch (rnnType) {

View File

@ -39,8 +39,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -145,92 +145,6 @@
</execution> </execution>
</executions> </executions>
</plugin> </plugin>
<plugin>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-plugin</artifactId>
<version>1.4.30-M1</version>
<configuration>
<args>
<arg>-Xjsr305=strict</arg>
</args>
<compilerPlugins>
<plugin>spring</plugin>
<plugin>jpa</plugin>
</compilerPlugins>
</configuration>
<dependencies>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-allopen</artifactId>
<version>${kotlin.version}</version>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-noarg</artifactId>
<version>${kotlin.version}</version>
</dependency>
</dependencies>
<executions>
<execution>
<id>compile</id>
<goals> <goal>compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/main/stubs</sourceDir>
<sourceDir>${project.basedir}/src/main/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/main/java</sourceDir>
<sourceDir>${project.basedir}/src/main/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
<execution>
<id>test-compile</id>
<goals> <goal>test-compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/test/stubs</sourceDir>
<sourceDir>${project.basedir}/src/test/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/test/java</sourceDir>
<sourceDir>${project.basedir}/src/test/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
</executions>
</plugin>
<!-- https://kotlinlang.org/docs/reference/using-maven.html -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.5.1</version>
<executions>
<!-- Replacing default-compile as it is treated specially by maven -->
<execution>
<id>default-compile</id>
<phase>none</phase>
</execution>
<!-- Replacing default-testCompile as it is treated specially by maven -->
<execution>
<id>default-testCompile</id>
<phase>none</phase>
</execution>
<execution>
<id>java-compile</id>
<phase>compile</phase>
<goals> <goal>compile</goal> </goals>
</execution>
<execution>
<id>java-test-compile</id>
<phase>test-compile</phase>
<goals> <goal>testCompile</goal> </goals>
</execution>
</executions>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
</plugins> </plugins>
</build> </build>
@ -244,7 +158,10 @@
<groupId>org.junit.jupiter</groupId> <groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId> <artifactId>junit-jupiter-engine</artifactId>
</dependency> </dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
</dependency>
<dependency> <dependency>
<groupId>org.jetbrains.kotlin</groupId> <groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-stdlib-jdk8</artifactId> <artifactId>kotlin-stdlib-jdk8</artifactId>
@ -261,11 +178,14 @@
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>samediff-import-tensorflow</artifactId> <artifactId>samediff-import-tensorflow</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>compile</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>samediff-import-onnx</artifactId> <artifactId>samediff-import-onnx</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>compile</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>

View File

@ -22,11 +22,6 @@ package org.nd4j;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass; import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.imports.tfgraphs.TFGraphTestAllLibnd4j;
import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
import org.nd4j.imports.tfgraphs.TFGraphTestList;
import org.nd4j.imports.tfgraphs.TFGraphTestZooModels;
import org.nd4j.imports.listeners.ImportModelDebugger;
import java.util.*; import java.util.*;
@Slf4j @Slf4j
@ -36,11 +31,6 @@ public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
protected Set<Class<?>> getExclusions() { protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>(Arrays.asList( return new HashSet<>(Arrays.asList(
TFGraphTestAllSameDiff.class,
TFGraphTestAllLibnd4j.class,
TFGraphTestList.class,
TFGraphTestZooModels.class,
ImportModelDebugger.class //Run manually only, otherwise ignored
)); ));
} }

View File

@ -20,19 +20,16 @@
package org.nd4j; package org.nd4j;
import org.bytedeco.javacpp.Loader;
import org.junit.AfterClass; import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.Suite; import org.junit.runners.Suite;
import org.nd4j.autodiff.opvalidation.*; import org.nd4j.autodiff.opvalidation.*;
import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff; //import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.function.Function;
import static org.junit.Assume.assumeFalse; import static org.junit.Assume.assumeFalse;
@ -49,7 +46,7 @@ import static org.junit.Assume.assumeFalse;
TransformOpValidation.class, TransformOpValidation.class,
//TF import tests //TF import tests
TFGraphTestAllSameDiff.class //TFGraphTestAllSameDiff.class
//TFGraphTestAllLibnd4j.class //TFGraphTestAllLibnd4j.class
}) })
//IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test" //IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test"

View File

@ -27,10 +27,12 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.ImportClassMapping; import org.nd4j.imports.converters.ImportClassMapping;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.NoOp; import org.nd4j.linalg.api.ops.NoOp;
import org.nd4j.linalg.api.ops.compat.CompatSparseToDense; import org.nd4j.linalg.api.ops.compat.CompatSparseToDense;
@ -122,13 +124,11 @@ import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
@Disabled("No longer relevant after model import rewrite.") @Disabled("No longer relevant after model import rewrite.")
public class TestOpMapping extends BaseNd4jTest { public class TestOpMapping extends BaseNd4jTestWithBackends {
Set<Class<? extends DifferentialFunction>> subTypes; Set<Class<? extends DifferentialFunction>> subTypes;
public TestOpMapping(Nd4jBackend b){ public TestOpMapping() {
super(b);
Reflections reflections = new Reflections("org.nd4j"); Reflections reflections = new Reflections("org.nd4j");
subTypes = reflections.getSubTypesOf(DifferentialFunction.class); subTypes = reflections.getSubTypesOf(DifferentialFunction.class);
} }
@ -146,6 +146,8 @@ public class TestOpMapping extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOpMappingCoverage() throws Exception { public void testOpMappingCoverage() throws Exception {
Map<String, DifferentialFunction> opNameMapping = ImportClassMapping.getOpNameMapping(); Map<String, DifferentialFunction> opNameMapping = ImportClassMapping.getOpNameMapping();
Map<String, DifferentialFunction> tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions(); Map<String, DifferentialFunction> tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions();
@ -196,7 +198,9 @@ public class TestOpMapping extends BaseNd4jTest {
} }
@Test @Test
public void testOpsInNamespace() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOpsInNamespace(Nd4jBackend backend) throws Exception {
//Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't //Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't
// want to add to a namespace for some reason) // want to add to a namespace for some reason)
//Note that we ignore "*Bp", "*Gradient", "*Derivative" etc ops //Note that we ignore "*Bp", "*Gradient", "*Derivative" etc ops
@ -354,8 +358,11 @@ public class TestOpMapping extends BaseNd4jTest {
s.add(Assign.class); s.add(Assign.class);
} }
@Test @Disabled @Test
public void generateOpClassList() throws Exception{ @Disabled
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void generateOpClassList(Nd4jBackend backend) throws Exception{
Reflections reflections = new Reflections("org.nd4j"); Reflections reflections = new Reflections("org.nd4j");
Set<Class<? extends DifferentialFunction>> subTypes = reflections.getSubTypesOf(DifferentialFunction.class); Set<Class<? extends DifferentialFunction>> subTypes = reflections.getSubTypesOf(DifferentialFunction.class);
@ -366,12 +373,7 @@ public class TestOpMapping extends BaseNd4jTest {
l.add(c); l.add(c);
} }
Collections.sort(l, new Comparator<Class<?>>() { Collections.sort(l, Comparator.comparing(Class::getName));
@Override
public int compare(Class<?> o1, Class<?> o2) {
return o1.getName().compareTo(o2.getName());
}
});
for(Class<?> c : l){ for(Class<?> c : l){
System.out.println(c.getName() + ".class,"); System.out.println(c.getName() + ".class,");

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -31,7 +33,7 @@ import org.nd4j.autodiff.samediff.internal.FrameIter;
import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.memory.NoOpMemoryMgr; import org.nd4j.autodiff.samediff.internal.memory.NoOpMemoryMgr;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -46,19 +48,17 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class TestSessions extends BaseNd4jTest { public class TestSessions extends BaseNd4jTestWithBackends {
public TestSessions(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering() {
return 'c'; return 'c';
} }
@Test @Test
public void testInferenceSessionBasic(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInferenceSessionBasic(Nd4jBackend backend) {
//So far: trivial test to check execution order //So far: trivial test to check execution order
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -90,7 +90,9 @@ public class TestSessions extends BaseNd4jTest {
@Test @Test
public void testInferenceSessionBasic2(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInferenceSessionBasic2(Nd4jBackend backend) {
//So far: trivial test to check execution order //So far: trivial test to check execution order
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -126,7 +128,9 @@ public class TestSessions extends BaseNd4jTest {
} }
@Test @Test
public void testMergeSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeSimple(Nd4jBackend backend) {
//This isn't really a sensible graph, as merge op behaviour is undefined when multiple inputs are available... //This isn't really a sensible graph, as merge op behaviour is undefined when multiple inputs are available...
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -162,7 +166,9 @@ public class TestSessions extends BaseNd4jTest {
@Test @Test
public void testSwitchSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSwitchSimple(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 3,3); SDVariable x = sd.placeHolder("x", DataType.FLOAT, 3,3);

View File

@ -21,10 +21,12 @@
package org.nd4j.autodiff.internal; package org.nd4j.autodiff.internal;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.internal.DependencyList; import org.nd4j.autodiff.samediff.internal.DependencyList;
import org.nd4j.autodiff.samediff.internal.DependencyTracker; import org.nd4j.autodiff.samediff.internal.DependencyTracker;
import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker; import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -35,19 +37,18 @@ import java.util.Collections;
import static junit.framework.TestCase.assertNotNull; import static junit.framework.TestCase.assertNotNull;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class TestDependencyTracker extends BaseNd4jTest { public class TestDependencyTracker extends BaseNd4jTestWithBackends {
public TestDependencyTracker(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';
} }
@Test @Test
public void testSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>(); DependencyTracker<String,String> dt = new DependencyTracker<>();
@ -93,8 +94,10 @@ public class TestDependencyTracker extends BaseNd4jTest {
assertTrue(dt.isEmpty()); assertTrue(dt.isEmpty());
} }
@Test @Test
public void testSatisfiedBeforeAdd(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSatisfiedBeforeAdd(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>(); DependencyTracker<String,String> dt = new DependencyTracker<>();
//Check different order of adding dependencies: i.e., mark X as satisfied, then add x -> y dependency //Check different order of adding dependencies: i.e., mark X as satisfied, then add x -> y dependency
@ -132,8 +135,10 @@ public class TestDependencyTracker extends BaseNd4jTest {
assertFalse(dt.hasNewAllSatisfied()); assertFalse(dt.hasNewAllSatisfied());
} }
@Test @Test
public void testMarkUnsatisfied(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMarkUnsatisfied(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>(); DependencyTracker<String,String> dt = new DependencyTracker<>();
dt.addDependency("y", "x"); dt.addDependency("y", "x");
@ -164,7 +169,9 @@ public class TestDependencyTracker extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIdentityDependencyTracker(){ public void testIdentityDependencyTracker(){
IdentityDependencyTracker<INDArray, String> dt = new IdentityDependencyTracker<>(); IdentityDependencyTracker<INDArray, String> dt = new IdentityDependencyTracker<>();
assertTrue(dt.isEmpty()); assertTrue(dt.isEmpty());

View File

@ -21,6 +21,8 @@
package org.nd4j.autodiff.opvalidation; package org.nd4j.autodiff.opvalidation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.GradCheckUtil; import org.nd4j.autodiff.validation.GradCheckUtil;
@ -38,12 +40,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class ActivationGradChecks extends BaseOpValidation { public class ActivationGradChecks extends BaseOpValidation {
public ActivationGradChecks(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testActivationGradientCheck1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testActivationGradientCheck1(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4)); SDVariable in = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4));
@ -61,7 +62,9 @@ public class ActivationGradChecks extends BaseOpValidation {
} }
@Test @Test
public void testActivationGradientCheck2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testActivationGradientCheck2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.DOUBLE, 3, 4); SDVariable x = sd.placeHolder("x", DataType.DOUBLE, 3, 4);

View File

@ -21,18 +21,14 @@
package org.nd4j.autodiff.opvalidation; package org.nd4j.autodiff.opvalidation;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
public abstract class BaseOpValidation extends BaseNd4jTest { public abstract class BaseOpValidation extends BaseNd4jTestWithBackends {
private DataType initialType; private DataType initialType = Nd4j.dataType();
public BaseOpValidation(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {

View File

@ -27,6 +27,8 @@ import java.util.List;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.OpValidation;
@ -65,9 +67,6 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class LayerOpValidation extends BaseOpValidation { public class LayerOpValidation extends BaseOpValidation {
public LayerOpValidation(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
@ -75,7 +74,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testXwPlusB() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testXwPlusB(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -109,7 +110,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReluLayer() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReluLayer(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -137,7 +140,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBiasAdd() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAdd(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -161,7 +166,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testConv2d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv2d(Nd4jBackend backend) {
//avg pool, batch norm, conv2d, max pool 2d, pooling2d, upsampling //avg pool, batch norm, conv2d, max pool 2d, pooling2d, upsampling
//Tested elsewhere: deconv2d, depthwise2d, LRN, sconv2d //Tested elsewhere: deconv2d, depthwise2d, LRN, sconv2d
@ -301,7 +308,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLrn2d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLrn2d(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}}; int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}};
@ -342,7 +351,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testIm2Col() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIm2Col(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873 //OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -381,7 +392,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testOutputShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOutputShape(Nd4jBackend backend) {
long[] inSize = {1, 8, 8, 3}; long[] inSize = {1, 8, 8, 3};
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -431,7 +444,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testAvgPool() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAvgPool(Nd4jBackend backend) {
long[] inSize = {1, 8, 8, 3}; //NHWC long[] inSize = {1, 8, 8, 3}; //NHWC
Pooling2DConfig conf = Pooling2DConfig.builder() Pooling2DConfig conf = Pooling2DConfig.builder()
@ -474,7 +489,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testConv3d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv3d(Nd4jBackend backend) {
//Pooling3d, Conv3D, batch norm //Pooling3d, Conv3D, batch norm
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -576,7 +593,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testDepthWiseConv2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepthWiseConv2dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int depthWise = 4; int depthWise = 4;
int kH = 2; int kH = 2;
@ -615,7 +634,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSeparableConv2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSeparableConv2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 2; int nIn = 2;
int nOut = 3; int nOut = 3;
@ -671,7 +692,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDeconv2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeconv2dBasic(Nd4jBackend backend) {
int nIn = 2; int nIn = 2;
int nOut = 3; int nOut = 3;
int kH = 2; int kH = 2;
@ -715,7 +738,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testConv2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv2dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
int kH = 2; int kH = 2;
@ -756,7 +781,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMaxPoolingArgMax() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxPoolingArgMax(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 3; int nIn = 3;
int kH = 2; int kH = 2;
@ -785,7 +812,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMaxPooling2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxPooling2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 3; int nIn = 3;
int kH = 2; int kH = 2;
@ -843,7 +872,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testAvgPooling2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAvgPooling2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 3; int nIn = 3;
int kH = 2; int kH = 2;
@ -892,7 +923,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testAvgPooling3dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAvgPooling3dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int kH = 2; int kH = 2;
int kW = 2; int kW = 2;
@ -929,7 +962,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMaxPooling3dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxPooling3dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int kH = 2; int kH = 2;
int kW = 2; int kW = 2;
@ -967,7 +1002,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testConv1dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
int k = 2; int k = 2;
@ -1002,7 +1039,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testConv1dCausal() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1dCausal(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
@ -1051,7 +1090,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testConv1dForward() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1dForward(Nd4jBackend backend) {
int nIn = 2; int nIn = 2;
int nOut = 1; int nOut = 1;
int kernel = 3; int kernel = 3;
@ -1094,7 +1135,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testConv3dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv3dBasic(Nd4jBackend backend) {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
int kH = 2; int kH = 2;
@ -1140,7 +1183,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDeConv3dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeConv3dBasic(Nd4jBackend backend) {
int nIn = 4; int nIn = 4;
int nOut = 3; int nOut = 3;
int kH = 2; int kH = 2;
@ -1185,7 +1230,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLayerNorm() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNorm(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1210,7 +1257,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLayerNorm4d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNorm4d(Nd4jBackend backend) {
int mb = 3; int mb = 3;
int ch = 4; int ch = 4;
for (boolean nchw : new boolean[]{true, false}) { for (boolean nchw : new boolean[]{true, false}) {
@ -1242,7 +1291,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void testLayerNormOP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormOP(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1258,7 +1309,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLayerNormNoBias() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormNoBias(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1281,7 +1334,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLayerNormOPNoBias() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormOPNoBias(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike(); final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1296,7 +1351,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLayerNormNoDeviation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormNoDeviation(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
random.putScalar(1, i, 7); random.putScalar(1, i, 7);
@ -1326,36 +1383,36 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test() @Test()
public void exceptionThrown_WhenConv1DConfigInvalid() { public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
int nIn = 3; int nIn = 3;
int nOut = 4; int nOut = 4;
int k = 2; int k = 2;
int mb = 3; int mb = 3;
int img = 28; int img = 28;
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray wArr = Nd4j.create(k, nIn, nOut); INDArray wArr = Nd4j.create(k, nIn, nOut);
INDArray inArr = Nd4j.create(mb, nIn, img); INDArray inArr = Nd4j.create(mb, nIn, img);
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
SDVariable w = sd.var("W", wArr); SDVariable w = sd.var("W", wArr);
SDVariable[] vars = new SDVariable[]{in, w}; SDVariable[] vars = new SDVariable[]{in, w};
Conv1DConfig conv1DConfig = Conv1DConfig.builder() Conv1DConfig conv1DConfig = Conv1DConfig.builder()
.k(k).p(-1).s(0) .k(k).p(-1).s(0)
.paddingMode(PaddingMode.VALID) .paddingMode(PaddingMode.VALID)
.build(); .build();
SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
}); });
} }
@Test() @Test()
public void exceptionThrown_WhenConv2DConfigInvalid() { public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1378,40 +1435,42 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test() @Test()
public void exceptionThrown_WhenConf3DInvalid() { public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
//NCDHW format //NCDHW format
int[] inSizeNCDHW = {2, 3, 4, 5, 5}; int[] inSizeNCDHW = {2, 3, 4, 5, 5};
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
for (boolean ncdhw : new boolean[]{true, false}) { for (boolean ncdhw : new boolean[]{true, false}) {
int nIn = inSizeNCDHW[1]; int nIn = inSizeNCDHW[1];
int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW)); int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW));
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", shape); SDVariable in = sd.var("in", shape);
SDVariable out; SDVariable out;
String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape); String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape);
SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC] SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC]
SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10)); SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10));
out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder() out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder()
.dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC) .dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC)
.isSameMode(true) .isSameMode(true)
.kH(2).kW(2).kD(2) .kH(2).kW(2).kD(2)
.sD(1).sH(1).sW(-1).dW(-1) .sD(1).sH(1).sW(-1).dW(-1)
.build()); .build());
} }
}); });
} }
@Test @Test
public void testLayerNormMixedOrders() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormMixedOrders(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f'); INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');
INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f'); INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f');
@ -1458,7 +1517,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBiasAdd_nchw_nhwc() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAdd_nchw_nhwc(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
for (boolean nchw : new boolean[]{true, false}) { for (boolean nchw : new boolean[]{true, false}) {
@ -1489,6 +1550,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepthwiseConv2D(){ public void testDepthwiseConv2D(){
int bS = 10; int bS = 10;
@ -1527,7 +1590,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void LSTMLayerTestCase1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void LSTMLayerTestCase1(Nd4jBackend backend) {
int bS = 5; int bS = 5;
int nIn = 3; int nIn = 3;
@ -1602,7 +1667,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test @Test
public void LSTMLayerTestCase2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void LSTMLayerTestCase2(Nd4jBackend backend) {
int bS = 5; int bS = 5;
int nIn = 3; int nIn = 3;
int numUnits = 7; int numUnits = 7;
@ -1660,7 +1727,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void LSTMLayerTestCase3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void LSTMLayerTestCase3(Nd4jBackend backend) {
int bS = 5; int bS = 5;
int nIn = 3; int nIn = 3;
int numUnits = 7; int numUnits = 7;
@ -1721,7 +1790,9 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test @Test
public void GRUTestCase() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void GRUTestCase(Nd4jBackend backend) {
int bS = 5; int bS = 5;
int nIn = 4; int nIn = 4;
int nOut = 6; int nOut = 6;

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -43,9 +45,7 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class LossOpValidation extends BaseOpValidation { public class LossOpValidation extends BaseOpValidation {
public LossOpValidation(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
@ -56,7 +56,9 @@ public class LossOpValidation extends BaseOpValidation {
public static final Set<String> NO_BP_YET = new HashSet<>(); public static final Set<String> NO_BP_YET = new HashSet<>();
@Test @Test
public void testLoss2d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoss2d(Nd4jBackend backend) {
final List<String> oneDimensionalOutputFns = Arrays.asList("cosine", "mpwse", "softmaxxent", "softmaxxent_smooth", "mpwse", "sparsesoftmax"); final List<String> oneDimensionalOutputFns = Arrays.asList("cosine", "mpwse", "softmaxxent", "softmaxxent_smooth", "mpwse", "sparsesoftmax");
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -69,7 +71,7 @@ public class LossOpValidation extends BaseOpValidation {
"absdiff", "cosine", "hinge", "huber", "log", "mse", "absdiff", "cosine", "hinge", "huber", "log", "mse",
"sigmoidxent", "sigmoidxent_smooth", "softmaxxent", "softmaxxent_smooth", "mpwse", "sigmoidxent", "sigmoidxent_smooth", "softmaxxent", "softmaxxent_smooth", "mpwse",
"sparsesoftmax" "sparsesoftmax"
}) { }) {
for(String weights : new String[]{"none", "scalar", "perExample", "perOutput"}) { for(String weights : new String[]{"none", "scalar", "perExample", "perOutput"}) {
@ -368,6 +370,8 @@ public class LossOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCosineDistance(){ public void testCosineDistance(){
INDArray arr = Nd4j.create(new double[][]{{-0.3, -0.2, -0.1}, {0, 0.1, 0.2}}); INDArray arr = Nd4j.create(new double[][]{{-0.3, -0.2, -0.1}, {0, 0.1, 0.2}});
INDArray label = Nd4j.create(new double[][]{{1.0, 2.0, 3.0}, {-1.0, 2.0, 1.0}}); INDArray label = Nd4j.create(new double[][]{{1.0, 2.0, 3.0}, {-1.0, 2.0, 1.0}});
@ -386,6 +390,8 @@ public class LossOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testL2Loss(){ public void testL2Loss(){
for( int rank=0; rank<=3; rank++ ){ for( int rank=0; rank<=3; rank++ ){
@ -428,7 +434,9 @@ public class LossOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNonZeroResult() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNonZeroResult(Nd4jBackend backend) {
INDArray predictions = Nd4j.rand(DataType.DOUBLE, 10, 5); INDArray predictions = Nd4j.rand(DataType.DOUBLE, 10, 5);
INDArray w = Nd4j.scalar(1.0); INDArray w = Nd4j.scalar(1.0);
INDArray label = Nd4j.rand(DataType.DOUBLE, 10, 5); INDArray label = Nd4j.rand(DataType.DOUBLE, 10, 5);
@ -486,6 +494,8 @@ public class LossOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void TestStdLossMixedDataType(){ public void TestStdLossMixedDataType(){
// Default Data Type in this test suite is Double. // Default Data Type in this test suite is Double.
// This test used to throw an Exception that we have mixed data types. // This test used to throw an Exception that we have mixed data types.

View File

@ -23,6 +23,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -78,13 +80,12 @@ import static org.junit.Assume.assumeNotNull;
@Slf4j @Slf4j
public class MiscOpValidation extends BaseOpValidation { public class MiscOpValidation extends BaseOpValidation {
public MiscOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testGradientAutoBroadcast1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGradientAutoBroadcast1(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -171,7 +172,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testGradientAutoBroadcast2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGradientAutoBroadcast2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -260,7 +263,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testGradientAutoBroadcast3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGradientAutoBroadcast3(Nd4jBackend backend) {
//These tests: output size > input sizes //These tests: output size > input sizes
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -368,7 +373,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testScatterOpGradients() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScatterOpGradients(Nd4jBackend backend) {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
for (int i = 0; i < 7; i++) { for (int i = 0; i < 7; i++) {
@ -470,6 +477,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScatterUpdate(){ public void testScatterUpdate(){
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3); INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3);
INDArray updates = Nd4j.create(new float[][]{ INDArray updates = Nd4j.create(new float[][]{
@ -491,7 +500,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testGatherGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -542,6 +553,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrace(){ public void testTrace(){
//TODO need to work out how to handle shape_op for scalars... //TODO need to work out how to handle shape_op for scalars...
//OpValidationSuite.ignoreFailing(); //OpValidationSuite.ignoreFailing();
@ -567,7 +580,9 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
public void testTensorGradTensorMmul() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTensorGradTensorMmul(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -589,7 +604,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMulGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMulGradient(Nd4jBackend backend) {
INDArray arr1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray arr1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray arr2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray arr2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
@ -654,22 +671,21 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
public void testMmulGradientManual() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulGradientManual(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
Map<String, INDArray> inputs = new HashMap<>(); Map<String, INDArray> inputs = new HashMap<>();
inputs.put("x", sumInput); inputs.put("x", sumInput);
inputs.put("y", sumInput.dup()); inputs.put("y", sumInput.dup());
sameDiff.defineFunction("mmulGradient", new SameDiffFunctionDefinition() { sameDiff.defineFunction("mmulGradient", (sameDiff1, inputs1, variableInputs) -> {
@Override SDVariable input = sameDiff1.var("x", inputs1.get("x"));
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) { SDVariable input2 = sameDiff1.var("y", inputs1.get("y"));
SDVariable input = sameDiff.var("x", inputs.get("x")); SDVariable exp = sameDiff1.mmul(input, input2);
SDVariable input2 = sameDiff.var("y", inputs.get("y")); SDVariable sum = sameDiff1.sum(exp, Integer.MAX_VALUE);
SDVariable exp = sameDiff.mmul(input, input2); return new SDVariable[]{sum};
SDVariable sum = sameDiff.sum(exp, Integer.MAX_VALUE);
return new SDVariable[]{sum};
}
}, inputs); }, inputs);
@ -698,6 +714,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulGradients(){ public void testMmulGradients(){
int[] aShape = new int[]{2,3}; int[] aShape = new int[]{2,3};
int[] bShape = new int[]{3,4}; int[] bShape = new int[]{3,4};
@ -749,7 +767,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBatchMmulBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBatchMmulBasic(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6873 OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6873
int M = 5; int M = 5;
int N = 3; int N = 3;
@ -774,7 +794,9 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
public void testMmulWithTranspose() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulWithTranspose(Nd4jBackend backend) {
//Here: [x,3]^T * [x,4] = [3,4] //Here: [x,3]^T * [x,4] = [3,4]
@ -811,6 +833,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulOutputSizeCalculation(){ public void testMmulOutputSizeCalculation(){
//[3,2] x [2,4] with result transpose: output shape [4,3] //[3,2] x [2,4] with result transpose: output shape [4,3]
INDArray a = Nd4j.create(3,2); INDArray a = Nd4j.create(3,2);
@ -820,7 +844,7 @@ public class MiscOpValidation extends BaseOpValidation {
.transposeA(false) .transposeA(false)
.transposeB(false) .transposeB(false)
.transposeResult(true) .transposeResult(true)
.build()); .build());
val outShapes = Nd4j.getExecutioner().calculateOutputShape(m); val outShapes = Nd4j.getExecutioner().calculateOutputShape(m);
assertArrayEquals(new long[]{4,3}, outShapes.get(0).getShape()); assertArrayEquals(new long[]{4,3}, outShapes.get(0).getShape());
@ -843,6 +867,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFillOp(){ public void testFillOp(){
INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT); INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT);
@ -857,6 +883,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm(){ public void testClipByNorm(){
//Expected: if array.norm2(1) is less than 1.0, not modified //Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -889,6 +917,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm2(){ public void testClipByNorm2(){
//Expected: if array.norm2(1) is less than 1.0, not modified //Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -932,6 +962,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm1(){ public void testClipByNorm1(){
//Expected: if array.norm2(1) is less than 1.0, not modified //Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -972,6 +1004,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm0(){ public void testClipByNorm0(){
//Expected: if array.norm2(0) is less than 1.0, not modified //Expected: if array.norm2(0) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -1001,6 +1035,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCumSum(){ public void testCumSum(){
List<String> failing = new ArrayList<>(); List<String> failing = new ArrayList<>();
@ -1066,6 +1102,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCumProd(){ public void testCumProd(){
List<String> failing = new ArrayList<>(); List<String> failing = new ArrayList<>();
@ -1134,6 +1172,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot1(){ public void testOneHot1(){
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -1164,6 +1204,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHotOp(){ public void testOneHotOp(){
//https://www.tensorflow.org/api_docs/python/tf/one_hot //https://www.tensorflow.org/api_docs/python/tf/one_hot
//https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp //https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp
@ -1178,7 +1220,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testOneHot2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot2(Nd4jBackend backend) {
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1); INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
@ -1198,7 +1242,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testOneHot4() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot4(Nd4jBackend backend) {
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1); INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
@ -1218,7 +1264,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testOneHot3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot3(Nd4jBackend backend) {
//https://github.com/deeplearning4j/deeplearning4j/issues/6872 //https://github.com/deeplearning4j/deeplearning4j/issues/6872
//https://www.tensorflow.org/api_docs/python/tf/one_hot //https://www.tensorflow.org/api_docs/python/tf/one_hot
@ -1253,6 +1301,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLinspace(){ public void testLinspace(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable out = sd.linspace("linspace", DataType.DOUBLE, 1,10,10); SDVariable out = sd.linspace("linspace", DataType.DOUBLE, 1,10,10);
@ -1266,6 +1316,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLinspace2(){ public void testLinspace2(){
OpValidationSuite.ignoreFailing(); //TODO 2019/01/18 OpValidationSuite.ignoreFailing(); //TODO 2019/01/18
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1280,7 +1332,9 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
public void testShapeFn() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShapeFn(Nd4jBackend backend) {
INDArray in = Nd4j.create(new long[]{1, 2}); INDArray in = Nd4j.create(new long[]{1, 2});
@ -1294,7 +1348,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testShapeFn2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShapeFn2(Nd4jBackend backend) {
INDArray i = Nd4j.create(1,3); INDArray i = Nd4j.create(1,3);
@ -1307,6 +1363,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeRank1(){ public void testMergeRank1(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5)); SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5));
@ -1325,7 +1383,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDiagPart() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagPart(Nd4jBackend backend) {
INDArray i = Nd4j.create(5,5); INDArray i = Nd4j.create(5,5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1337,7 +1397,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDiagShapeFn() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagShapeFn(Nd4jBackend backend) {
INDArray i = Nd4j.create(5,5); INDArray i = Nd4j.create(5,5);
CustomOp op = new DiagPart(i, null); CustomOp op = new DiagPart(i, null);
@ -1350,6 +1412,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZerosOnesLike(){ public void testZerosOnesLike(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1392,6 +1456,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZerosLikeOp(){ public void testZerosLikeOp(){
INDArray arr = Nd4j.scalar(DataType.DOUBLE, 1.0); INDArray arr = Nd4j.scalar(DataType.DOUBLE, 1.0);
@ -1407,6 +1473,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConfusionMatrix(){ public void testConfusionMatrix(){
DataType dt = DataType.DOUBLE; DataType dt = DataType.DOUBLE;
@ -1443,6 +1511,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIsNonDecreasingIsStrictlyIncr(){ public void testIsNonDecreasingIsStrictlyIncr(){
List<long[]> shapes = Arrays.asList(null, new long[]{12}, new long[]{1,12}, new long[]{3,4}, new long[]{2,2,3}); List<long[]> shapes = Arrays.asList(null, new long[]{12}, new long[]{1,12}, new long[]{3,4}, new long[]{2,2,3});
@ -1506,6 +1576,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExtractImagePatches(){ public void testExtractImagePatches(){
/* /*
tf.reset_default_graph() tf.reset_default_graph()
@ -1553,6 +1625,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentProdBpSimple(){ public void testSegmentProdBpSimple(){
INDArray segmentIdxs = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT); INDArray segmentIdxs = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT);
@ -1573,6 +1647,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulRank4() throws Exception { public void testMmulRank4() throws Exception {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1608,6 +1684,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulRank4_simple(){ public void testMmulRank4_simple(){
INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64); INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
@ -1634,6 +1712,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNthElementRank1(){ public void testNthElementRank1(){
INDArray in = Nd4j.createFromArray(new double[]{0,1,2,3,4,5,6,7,8,9}); INDArray in = Nd4j.createFromArray(new double[]{0,1,2,3,4,5,6,7,8,9});
INDArray n = Nd4j.scalar(0); INDArray n = Nd4j.scalar(0);
@ -1656,6 +1736,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTensorMmulShape(){ public void testTensorMmulShape(){
INDArray a = Nd4j.create(new double[]{2}).reshape(1); INDArray a = Nd4j.create(new double[]{2}).reshape(1);
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2); INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
@ -1674,6 +1756,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTensorMmulShape2(){ public void testTensorMmulShape2(){
INDArray a = Nd4j.create(new double[]{2}).reshape(1); INDArray a = Nd4j.create(new double[]{2}).reshape(1);
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2); INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
@ -1682,6 +1766,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStopGradient(){ public void testStopGradient(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1701,6 +1787,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckNumerics(){ public void testCheckNumerics(){
OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/7927 OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/7927
@ -1744,7 +1832,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testCheckNumerics2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckNumerics2(Nd4jBackend backend) {
INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4); INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4);
INDArray msg = Nd4j.scalar("My error message!"); INDArray msg = Nd4j.scalar("My error message!");
@ -1757,6 +1847,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testHistogramFixedWidth(){ public void testHistogramFixedWidth(){
//Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf] //Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf]
INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9); INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9);
@ -1775,6 +1867,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicPartition(){ public void testDynamicPartition(){
INDArray data = Nd4j.createFromArray(2, 1, 2, 0); INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
@ -1793,6 +1887,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testListDiff(){ public void testListDiff(){
INDArray x = Nd4j.createFromArray(0, 1, 2, 3); INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
INDArray y = Nd4j.createFromArray(3, 1); INDArray y = Nd4j.createFromArray(3, 1);
@ -1812,7 +1908,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDivideNoNan() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDivideNoNan(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff() OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff()
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1836,7 +1934,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDigamma() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDigamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -1851,7 +1951,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testFlatten() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFlatten(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1873,7 +1975,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testFusedBatchNorm() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFusedBatchNorm(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1918,7 +2022,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testIgamma() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIgamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -1934,7 +2040,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testIgammaC() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIgammaC(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -1951,7 +2059,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLgamma() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLgamma(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1976,7 +2086,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLu() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLu(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2007,7 +2119,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMatrixBandPart() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixBandPart(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2037,7 +2151,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testPolygamma() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPolygamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -2053,7 +2169,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTriangularSolve() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTriangularSolve(Nd4jBackend backend) {
INDArray a = Nd4j.createFromArray(new float[]{ INDArray a = Nd4j.createFromArray(new float[]{
3.f, 0.f, 0.f, 0.f, 3.f, 0.f, 0.f, 0.f,
@ -2077,7 +2195,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBiasAdd() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAdd(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2106,7 +2226,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBiasAddGrad() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAddGrad(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -2126,7 +2248,9 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testRoll() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRoll(Nd4jBackend backend) {
INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}). 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}).
@ -2146,6 +2270,8 @@ public class MiscOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSeqMask(){ public void testSeqMask(){
INDArray arr = Nd4j.createFromArray(1,2,3); INDArray arr = Nd4j.createFromArray(1,2,3);
INDArray maxLen = Nd4j.scalar(4); INDArray maxLen = Nd4j.scalar(4);

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -51,12 +53,11 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class RandomOpValidation extends BaseOpValidation { public class RandomOpValidation extends BaseOpValidation {
public RandomOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testRandomOpsSDVarShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomOpsSDVarShape(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -157,7 +158,9 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testRandomOpsLongShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomOpsLongShape(Nd4jBackend backend) {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
for (long[] shape : Arrays.asList(new long[]{1000}, new long[]{100, 10}, new long[]{40, 5, 5})) { for (long[] shape : Arrays.asList(new long[]{1000}, new long[]{100, 10}, new long[]{40, 5, 5})) {
@ -283,6 +286,8 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomBinomial(){ public void testRandomBinomial(){
INDArray z = Nd4j.create(new long[]{10}); INDArray z = Nd4j.create(new long[]{10});
@ -293,7 +298,9 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testUniformRankSimple() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUniformRankSimple(Nd4jBackend backend) {
INDArray arr = Nd4j.createFromArray(new double[]{100.0}); INDArray arr = Nd4j.createFromArray(new double[]{100.0});
// OpTestCase tc = new OpTestCase(DynamicCustomOp.builder("randomuniform") // OpTestCase tc = new OpTestCase(DynamicCustomOp.builder("randomuniform")
@ -325,7 +332,9 @@ public class RandomOpValidation extends BaseOpValidation {
@Test @Test
public void testRandomExponential() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomExponential(Nd4jBackend backend) {
long length = 1_000_000; long length = 1_000_000;
INDArray shape = Nd4j.createFromArray(new double[]{length}); INDArray shape = Nd4j.createFromArray(new double[]{length});
INDArray out = Nd4j.createUninitialized(new long[]{length}); INDArray out = Nd4j.createUninitialized(new long[]{length});
@ -347,6 +356,8 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRange(){ public void testRange(){
//Technically deterministic, not random... //Technically deterministic, not random...
@ -380,6 +391,8 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllEmptyReduce(){ public void testAllEmptyReduce(){
INDArray x = Nd4j.createFromArray(true, true, true); INDArray x = Nd4j.createFromArray(true, true, true);
All all = new All(x); All all = new All(x);
@ -389,6 +402,8 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUniformDtype(){ public void testUniformDtype(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){ for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
@ -417,6 +432,8 @@ public class RandomOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomExponential2(){ public void testRandomExponential2(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
DynamicCustomOp op = DynamicCustomOp.builder("random_exponential") DynamicCustomOp op = DynamicCustomOp.builder("random_exponential")

View File

@ -24,6 +24,8 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.validation.OpTestCase; import org.nd4j.autodiff.validation.OpTestCase;
import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -51,10 +53,6 @@ public class ReductionBpOpValidation extends BaseOpValidation {
private DataType initialType; private DataType initialType;
public ReductionBpOpValidation(Nd4jBackend backend) {
super(backend);
}
@BeforeEach @BeforeEach
public void before() { public void before() {
Nd4j.create(1); Nd4j.create(1);
@ -71,14 +69,16 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@AfterEach @AfterEach
public void tearDown() { public void tearDown(Nd4jBackend backend) {
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
} }
@Test @Test
public void testReduceSumBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduceSumBP(Nd4jBackend backend) {
//Full array reduction //Full array reduction
//reduce_sum_bp op: has 2 inputs (original pre-reduce input, and gradient at output (epsilon)) //reduce_sum_bp op: has 2 inputs (original pre-reduce input, and gradient at output (epsilon))
@ -104,7 +104,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReduceSumAlongDim0BP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduceSumAlongDim0BP(Nd4jBackend backend) {
//Reduction along dimension //Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar //Inputs/outputs as before - but note that the output is no longer a scalar
@ -130,7 +132,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReduceSumAlongDim1BP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduceSumAlongDim1BP(Nd4jBackend backend) {
//Reduction along dimension //Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar //Inputs/outputs as before - but note that the output is no longer a scalar
@ -158,7 +162,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test @Test
public void testMeanBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanBP(Nd4jBackend backend) {
//dL/dIn_i = dL/dOut * dOut/dIn_i = dL/dOut * (1/N * sum_j (in_j)) //dL/dIn_i = dL/dOut * dOut/dIn_i = dL/dOut * (1/N * sum_j (in_j))
// = 1/N * dL/dOut // = 1/N * dL/dOut
@ -189,7 +195,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMeanBP_Rank1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanBP_Rank1(Nd4jBackend backend) {
INDArray dLdOut = Nd4j.scalar(0.5); INDArray dLdOut = Nd4j.scalar(0.5);
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3}); INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5 / 3); INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5 / 3);
@ -202,7 +210,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMeanAlongDim0BP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanAlongDim0BP(Nd4jBackend backend) {
//Reduction along dimension //Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar //Inputs/outputs as before - but note that the output is no longer a scalar
@ -230,7 +240,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMeanAlongDim1BP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanAlongDim1BP(Nd4jBackend backend) {
//Reduction along dimension //Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar //Inputs/outputs as before - but note that the output is no longer a scalar
@ -258,7 +270,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test @Test
public void testMinBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMinBP(Nd4jBackend backend) {
//Full array min reduction //Full array min reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i //dL/dIn_i = dL/dOut * dOut/dIn_i
@ -297,7 +311,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMinAlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMinAlongDimensionBP(Nd4jBackend backend) {
//Full array min reduction //Full array min reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i //dL/dIn_i = dL/dOut * dOut/dIn_i
@ -340,7 +356,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMaxBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxBP(Nd4jBackend backend) {
//Full array max reduction //Full array max reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i //dL/dIn_i = dL/dOut * dOut/dIn_i
@ -370,7 +388,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMaxAlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxAlongDimensionBP(Nd4jBackend backend) {
//Full array min reduction //Full array min reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i //dL/dIn_i = dL/dOut * dOut/dIn_i
@ -413,7 +433,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testProdBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testProdBP(Nd4jBackend backend) {
//Full array product reduction //Full array product reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i //dL/dIn_i = dL/dOut * dOut/dIn_i
@ -442,7 +464,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testProdAlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testProdAlongDimensionBP(Nd4jBackend backend) {
//dL/dIn_i = dL/dOut * dOut/dIn_i //dL/dIn_i = dL/dOut * dOut/dIn_i
// = dL/dOut * d(prod(in))/dIn_i // = dL/dOut * d(prod(in))/dIn_i
// = dL/dOut * (prod(in) / in_i) // = dL/dOut * (prod(in) / in_i)
@ -498,7 +522,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStdevBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdevBP(Nd4jBackend backend) {
//If out = stdev(in) then: //If out = stdev(in) then:
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = (in_i-mean)/(stdev * (n-1)) //dOut/dIn_i = (in_i-mean)/(stdev * (n-1))
@ -534,7 +560,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStdevBP_Rank1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdevBP_Rank1(Nd4jBackend backend) {
INDArray dLdOut = Nd4j.scalar(0.5); INDArray dLdOut = Nd4j.scalar(0.5);
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3}); INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
double stdev = preReduceInput.stdNumber(true).doubleValue(); double stdev = preReduceInput.stdNumber(true).doubleValue();
@ -555,7 +583,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStdevAlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdevAlongDimensionBP(Nd4jBackend backend) {
//If out = stdev(in) then: //If out = stdev(in) then:
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = (in_i-mean)/(stdev * (n-1)) //dOut/dIn_i = (in_i-mean)/(stdev * (n-1))
@ -600,7 +630,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testVarianceBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVarianceBP(Nd4jBackend backend) {
//If out = variance(in) then: //If out = variance(in) then:
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = 2*(in_i-mean)/(n-1) //dOut/dIn_i = 2*(in_i-mean)/(n-1)
@ -636,7 +668,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testVarianceAlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVarianceAlongDimensionBP(Nd4jBackend backend) {
//If out = variance(in) then: //If out = variance(in) then:
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = 2*(in_i-mean)/(n-1) //dOut/dIn_i = 2*(in_i-mean)/(n-1)
@ -678,7 +712,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test @Test
public void testCumSumBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCumSumBP(Nd4jBackend backend) {
//Standard case, non-reverse, non-exclusive //Standard case, non-reverse, non-exclusive
//dL/dIn_i = sum_j dL/dOut_j * dOut_j/dIn_i //dL/dIn_i = sum_j dL/dOut_j * dOut_j/dIn_i
// = sum_j dL/dOut_j * d(in_0 + ... + in_j)/dIn_i // = sum_j dL/dOut_j * d(in_0 + ... + in_j)/dIn_i
@ -748,7 +784,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test @Test
public void testNorm2Bp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm2Bp(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * x/|x|_2 // = dL/dOut * x/|x|_2
@ -775,7 +813,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNorm2AlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm2AlongDimensionBP(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * x/|x|_2 // = dL/dOut * x/|x|_2
@ -808,7 +848,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNorm1Bp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm1Bp(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * sgn(in) // = dL/dOut * sgn(in)
@ -835,7 +877,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNorm1AlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm1AlongDimensionBP(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * sgn(in) // = dL/dOut * sgn(in)
@ -867,7 +911,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNormMaxBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormMaxBp(Nd4jBackend backend) {
//out = max_i (|in_i|) //out = max_i (|in_i|)
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise) // = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise)
@ -897,7 +943,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNormMaxAlongDimensionBP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormMaxAlongDimensionBP(Nd4jBackend backend) {
//out = max_i (|in_i|) //out = max_i (|in_i|)
//dL/dIn = dL/dOut * dOut/dIn //dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise) // = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise)

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -76,16 +77,13 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class ReductionOpValidation extends BaseOpValidation { public class ReductionOpValidation extends BaseOpValidation {
public ReductionOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testStdev() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdev(Nd4jBackend backend) {
List<String> errors = new ArrayList<>(); List<String> errors = new ArrayList<>();
for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345, DataType.DOUBLE)) { for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345, DataType.DOUBLE)) {
@ -111,7 +109,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testZeroCount() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZeroCount(Nd4jBackend backend) {
List<String> allFailed = new ArrayList<>(); List<String> allFailed = new ArrayList<>();
for (int i = 0; i < 21; i++) { for (int i = 0; i < 21; i++) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -145,7 +145,9 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test @Test
public void testZeroFraction() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZeroFraction(Nd4jBackend backend) {
List<String> allFailed = new ArrayList<>(); List<String> allFailed = new ArrayList<>();
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -175,7 +177,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReductionGradientsSimple() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionGradientsSimple(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES //OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES
//Test reductions: final and only function //Test reductions: final and only function
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -344,7 +348,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReductionGradients1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionGradients1(Nd4jBackend backend) {
//Test reductions: final, but *not* the only function //Test reductions: final, but *not* the only function
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -472,7 +478,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReductionGradients2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionGradients2(Nd4jBackend backend) {
//Test reductions: NON-final function //Test reductions: NON-final function
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -650,7 +658,9 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test @Test
public void testReduce3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduce3(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int d0 = 3; int d0 = 3;
@ -755,7 +765,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMoments() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMoments(Nd4jBackend backend) {
for (int[] axes : new int[][]{{0}, {1}, {0, 1}}) { for (int[] axes : new int[][]{{0}, {1}, {0, 1}}) {
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -787,9 +799,11 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMomentsOp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMomentsOp(Nd4jBackend backend) {
int[] axes = new int[]{0}; int[] axes = new int[]{0};
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray outMean = Nd4j.createUninitialized(new long[]{4}); INDArray outMean = Nd4j.createUninitialized(new long[]{4});
INDArray outVar = Nd4j.createUninitialized(new long[]{4}); INDArray outVar = Nd4j.createUninitialized(new long[]{4});
@ -804,7 +818,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNormalizeMomentsOp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormalizeMomentsOp(Nd4jBackend backend) {
INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10); INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10);
INDArray ssSum = data.sum(0); INDArray ssSum = data.sum(0);
INDArray ssSqSum = data.mul(data).sum(0); INDArray ssSqSum = data.mul(data).sum(0);
@ -824,7 +840,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testAllAny() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllAny(Nd4jBackend backend) {
INDArray allZeros = Nd4j.zeros(DataType.FLOAT, 3, 4); INDArray allZeros = Nd4j.zeros(DataType.FLOAT, 3, 4);
INDArray allOnes = Nd4j.ones(DataType.FLOAT, 3, 4); INDArray allOnes = Nd4j.ones(DataType.FLOAT, 3, 4);
@ -852,7 +870,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testIndexAccum() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexAccum(Nd4jBackend backend) {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/); List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/);
@ -941,7 +961,9 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test @Test
public void testReduce3_2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduce3_2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int d0 = 3; int d0 = 3;
@ -1039,7 +1061,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReductionsBackwards() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionsBackwards(Nd4jBackend backend) {
// for (int i = 0; i < 7; i++) { // for (int i = 0; i < 7; i++) {
int i=5; int i=5;
{ {
@ -1108,6 +1132,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttention(){ public void testDotProductAttention(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3}); final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1127,12 +1153,14 @@ public class ReductionOpValidation extends BaseOpValidation {
t.norm1("out"); t.norm1("out");
String err = OpValidation.validate(new TestCase(sd) String err = OpValidation.validate(new TestCase(sd)
.expectedOutput("out", finalOut) .expectedOutput("out", finalOut)
.gradientCheck(true)); .gradientCheck(true));
assertNull(err); assertNull(err);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionWithMask(){ public void testDotProductAttentionWithMask(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3}); final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1163,6 +1191,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionMultiHeadInputWithMask(){ public void testDotProductAttentionMultiHeadInputWithMask(){
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
@ -1194,6 +1224,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionMultiHeadInput(){ public void testDotProductAttentionMultiHeadInput(){
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
@ -1221,6 +1253,8 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMultiHeadedDotProductAttention(){ public void testMultiHeadedDotProductAttention(){
final INDArray k = Nd4j.rand(new int[]{10, 4, 5}); final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
final INDArray v = Nd4j.rand(new int[]{10, 4, 5}); final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
@ -1272,6 +1306,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionWeirdInputs(){ public void testDotProductAttentionWeirdInputs(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3}); final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3}); final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1309,6 +1345,8 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMultiHeadedDotProductAttentionWeirdInputs(){ public void testMultiHeadedDotProductAttentionWeirdInputs(){
final INDArray k = Nd4j.rand(new int[]{10, 4, 5}); final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
final INDArray v = Nd4j.rand(new int[]{10, 4, 5}); final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
@ -1366,7 +1404,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
} }
@Test @Test
public void testSufficientStatisticsOp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSufficientStatisticsOp(Nd4jBackend backend) {
INDArray data = Nd4j.createFromArray(new double[]{ INDArray data = Nd4j.createFromArray(new double[]{
5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1., 5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1.,
1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5 1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5
@ -1392,7 +1432,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStandardDeviation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardDeviation(Nd4jBackend backend) {
for (boolean keepDims : new boolean[]{false, true}) { for (boolean keepDims : new boolean[]{false, true}) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1419,7 +1461,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSquaredNorm() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSquaredNorm(Nd4jBackend backend) {
for (boolean keepDims : new boolean[]{false, true}) { for (boolean keepDims : new boolean[]{false, true}) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1442,7 +1486,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testShannonEntropy() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShannonEntropy(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695 OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1462,7 +1508,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testEntropy() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEntropy(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1481,7 +1529,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testAMean() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAMean(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1502,7 +1552,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMean() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMean(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1523,7 +1575,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNorm1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm1(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1544,7 +1598,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNorm2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1565,7 +1621,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testNormMax() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormMax(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1586,7 +1644,9 @@ public class ReductionOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSoftmaxCrossEntropyWithLogitsLoss() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSoftmaxCrossEntropyWithLogitsLoss(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -43,12 +45,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j @Slf4j
public class RnnOpValidation extends BaseOpValidation { public class RnnOpValidation extends BaseOpValidation {
public RnnOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testRnnBlockCell(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRnnBlockCell(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int mb = 2; int mb = 2;
int nIn = 3; int nIn = 3;
@ -147,7 +148,9 @@ public class RnnOpValidation extends BaseOpValidation {
@Test @Test
public void testRnnBlockCellManualTFCompare() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRnnBlockCellManualTFCompare(Nd4jBackend backend) {
//Test case: "rnn/lstmblockcell/static_batch1_n3-2_tsLength1_noPH_noClip_fBias1_noIS" //Test case: "rnn/lstmblockcell/static_batch1_n3-2_tsLength1_noPH_noClip_fBias1_noIS"
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -209,6 +212,8 @@ public class RnnOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGRUCell(){ public void testGRUCell(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int mb = 2; int mb = 2;

View File

@ -28,6 +28,8 @@ import lombok.val;
import org.apache.commons.math3.linear.LUDecomposition; import org.apache.commons.math3.linear.LUDecomposition;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -67,9 +69,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
@Slf4j @Slf4j
public class ShapeOpValidation extends BaseOpValidation { public class ShapeOpValidation extends BaseOpValidation {
public ShapeOpValidation(Nd4jBackend backend) {
super(backend);
}
/* /*
To test: To test:
@ -83,7 +82,9 @@ public class ShapeOpValidation extends BaseOpValidation {
*/ */
@Test @Test
public void testConcat() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat(Nd4jBackend backend) {
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2}; // int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
int[] concatDim = new int[]{0, 0, 0}; int[] concatDim = new int[]{0, 0, 0};
List<List<int[]>> origShapes = new ArrayList<>(); List<List<int[]>> origShapes = new ArrayList<>();
@ -123,7 +124,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReshapeGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReshapeGradient(Nd4jBackend backend) {
//https://github.com/deeplearning4j/deeplearning4j/issues/6873 //https://github.com/deeplearning4j/deeplearning4j/issues/6873
int[] origShape = new int[]{3, 4, 5}; int[] origShape = new int[]{3, 4, 5};
@ -159,7 +162,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testPermuteGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermuteGradient(Nd4jBackend backend) {
int[] origShape = new int[]{3, 4, 5}; int[] origShape = new int[]{3, 4, 5};
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -197,6 +202,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRank(){ public void testRank(){
List<long[]> inShape = Arrays.asList(null, new long[]{1}, new long[]{6}, new long[]{3,4}, new long[]{3,4,5}); List<long[]> inShape = Arrays.asList(null, new long[]{1}, new long[]{6}, new long[]{3,4}, new long[]{3,4,5});
@ -224,7 +231,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testExpandDimsGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExpandDimsGradient(Nd4jBackend backend) {
val origShape = new long[]{3, 4}; val origShape = new long[]{3, 4};
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -280,7 +289,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSqueezeGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSqueezeGradient(Nd4jBackend backend) {
val origShape = new long[]{3, 4, 5}; val origShape = new long[]{3, 4, 5};
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -344,7 +355,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testSliceGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSliceGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
//Order here: original shape, begin, size //Order here: original shape, begin, size
@ -434,7 +447,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStridedSliceGradient() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
//Order here: original shape, begin, size //Order here: original shape, begin, size
@ -497,7 +512,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testMerge() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMerge(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -573,7 +590,7 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test() @Test()
public void testStack() { public void testStack(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -664,7 +681,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testUnStack() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnStack(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -752,7 +771,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTile() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTile(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<int[]> tileArg = Arrays.asList( List<int[]> tileArg = Arrays.asList(
@ -824,6 +845,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTileBp(){ public void testTileBp(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -857,6 +880,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTileBp2(){ public void testTileBp2(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -891,7 +916,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testReshape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReshape(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(-5, 6, 12)).reshape(3, 4); INDArray arr = Transforms.sigmoid(Nd4j.linspace(-5, 6, 12)).reshape(3, 4);
SDVariable x = sameDiff.var("x", arr); SDVariable x = sameDiff.var("x", arr);
@ -907,7 +934,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReshape2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReshape2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int[] origShape = new int[]{3, 4, 5}; int[] origShape = new int[]{3, 4, 5};
@ -930,7 +959,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testTranspose() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTranspose(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
SDVariable x = sameDiff.var("x", arr); SDVariable x = sameDiff.var("x", arr);
@ -942,6 +973,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTransposeOp(){ public void testTransposeOp(){
INDArray arr = Nd4j.linspace(1,15, 15).reshape(5,3); INDArray arr = Nd4j.linspace(1,15, 15).reshape(5,3);
@ -955,7 +988,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShape(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
val shape = new long[]{2, 3}; val shape = new long[]{2, 3};
SDVariable x = sameDiff.var("x", shape); SDVariable x = sameDiff.var("x", shape);
@ -970,7 +1005,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSize() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSize(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
val shape = new long[]{2, 3}; val shape = new long[]{2, 3};
SDVariable x = sameDiff.var("x", DataType.FLOAT, shape); SDVariable x = sameDiff.var("x", DataType.FLOAT, shape);
@ -984,7 +1021,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDiagShapeFn() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagShapeFn(Nd4jBackend backend) {
INDArray i = Nd4j.linspace(1, 16, 16).reshape(4,4); INDArray i = Nd4j.linspace(1, 16, 16).reshape(4,4);
OpTestCase op = new OpTestCase(new DiagPart(i, null)); OpTestCase op = new OpTestCase(new DiagPart(i, null));
@ -998,6 +1037,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute(){ public void testPermute(){
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5); INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
INDArray exp = in.permute(0,1,2); //No op INDArray exp = in.permute(0,1,2); //No op
@ -1012,6 +1053,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute2(){ public void testPermute2(){
for (int[] perm : new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}) { for (int[] perm : new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}) {
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5); INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
@ -1032,6 +1075,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConstant(){ public void testConstant(){
//OpValidationSuite.ignoreFailing(); //OpValidationSuite.ignoreFailing();
@ -1059,6 +1104,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnstackEdgeCase2(){ public void testUnstackEdgeCase2(){
for( int i=0; i<3; i++ ) { for( int i=0; i<3; i++ ) {
@ -1073,7 +1120,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void invertPermutation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void invertPermutation(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new float[] {3, 4, 0, 2, 1}).castTo(DataType.INT); INDArray ia = Nd4j.create(new float[] {3, 4, 0, 2, 1}).castTo(DataType.INT);
@ -1090,6 +1139,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherNd(){ public void testGatherNd(){
List<INDArray> indices = new ArrayList<>(); List<INDArray> indices = new ArrayList<>();
@ -1128,7 +1179,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testReverseSequence() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReverseSequence(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
float[] input_data = new float[]{ float[] input_data = new float[]{
1, 2, 3, 1, 2, 3,
@ -1174,6 +1227,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixDeterminant(){ public void testMatrixDeterminant(){
OpValidationSuite.ignoreFailing(); //Gradient check failing OpValidationSuite.ignoreFailing(); //Gradient check failing
@ -1195,6 +1250,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeterminant22(){ public void testDeterminant22(){
OpValidationSuite.ignoreFailing(); //Gradient check failing OpValidationSuite.ignoreFailing(); //Gradient check failing
@ -1219,6 +1276,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixDeterminant3(){ public void testMatrixDeterminant3(){
OpValidationSuite.ignoreFailing(); //Gradient checks failing OpValidationSuite.ignoreFailing(); //Gradient checks failing
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1250,6 +1309,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixDeterminant4(){ public void testMatrixDeterminant4(){
OpValidationSuite.ignoreFailing(); //Gradient checks failing OpValidationSuite.ignoreFailing(); //Gradient checks failing
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1270,6 +1331,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentOps(){ public void testSegmentOps(){
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
//https://github.com/deeplearning4j/deeplearning4j/issues/6952 //https://github.com/deeplearning4j/deeplearning4j/issues/6952
@ -1362,6 +1425,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentMean(){ public void testSegmentMean(){
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3); INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3);
INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2); INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2);
@ -1382,7 +1447,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSequenceMask() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSequenceMask(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2}); INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2});
// arr is not trainable, so it's constant in model // arr is not trainable, so it's constant in model
@ -1391,10 +1458,10 @@ public class ShapeOpValidation extends BaseOpValidation {
// Test with static max len // Test with static max len
int maxlen = 5; int maxlen = 5;
INDArray expected = Nd4j.create(new float[] { INDArray expected = Nd4j.create(new float[] {
1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f,
1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f,
1.f, 1.f, 0.f, 0.f, 0.f 1.f, 1.f, 0.f, 0.f, 0.f
}).reshape(3,5); }).reshape(3,5);
INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.FLOAT)); INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.FLOAT));
SDVariable result1 = sameDiff.sequenceMask(lengths, maxlen, DataType.FLOAT); SDVariable result1 = sameDiff.sequenceMask(lengths, maxlen, DataType.FLOAT);
assertArrayEquals(expected.shape(), result1.eval().shape()); assertArrayEquals(expected.shape(), result1.eval().shape());
@ -1416,6 +1483,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeshGrid(){ public void testMeshGrid(){
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -1472,6 +1541,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGather(){ public void testGather(){
List<INDArray> inArrs = new ArrayList<>(); List<INDArray> inArrs = new ArrayList<>();
List<Integer> axis = new ArrayList<>(); List<Integer> axis = new ArrayList<>();
@ -1541,7 +1612,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testGatherSimple() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherSimple(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Nd4j.create(new float[]{1, 2, 3, 4}, new long[]{2, 2}); INDArray arr = Nd4j.create(new float[]{1, 2, 3, 4}, new long[]{2, 2});
SDVariable x = sameDiff.var("x", arr); SDVariable x = sameDiff.var("x", arr);
@ -1551,7 +1624,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testGatherNdSingle() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherNdSingle(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(DataType.DOUBLE, 1, 24, 24)).reshape(2, 3, 4); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(DataType.DOUBLE, 1, 24, 24)).reshape(2, 3, 4);
INDArray arr2 = Nd4j.create(new float[]{1, 2, 3, 0, 1, 3, 1, 0, 2}, new long[]{3, 3}).castTo(DataType.INT); INDArray arr2 = Nd4j.create(new float[]{1, 2, 3, 0, 1, 3, 1, 0, 2}, new long[]{3, 3}).castTo(DataType.INT);
@ -1563,14 +1638,16 @@ public class ShapeOpValidation extends BaseOpValidation {
for (int i=0; i<3; i++){ for (int i=0; i<3; i++){
INDArray idx = arr2.get(point(i), NDArrayIndex.all()); INDArray idx = arr2.get(point(i), NDArrayIndex.all());
expected.putScalar(i, arr1.get(point(idx.getInt(0)), expected.putScalar(i, arr1.get(point(idx.getInt(0)),
point(idx.getInt(1)), point(idx.getInt(1)),
point(idx.getInt(2))).getDouble(0)); point(idx.getInt(2))).getDouble(0));
} }
assertEquals(expected, result.eval()); assertEquals(expected, result.eval());
} }
@Test @Test
public void testStack2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStack2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2); INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2);
@ -1581,7 +1658,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testParallelStack() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testParallelStack(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2); INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2);
@ -1593,7 +1672,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testUnStack2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnStack2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Nd4j.zeros(3, 2); INDArray arr1 = Nd4j.zeros(3, 2);
INDArray arr2 = Nd4j.ones(3, 2); INDArray arr2 = Nd4j.ones(3, 2);
@ -1606,7 +1687,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testPermuteSimple() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermuteSimple(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3)); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3));
SDVariable x = sameDiff.var("x", arr); SDVariable x = sameDiff.var("x", arr);
@ -1617,7 +1700,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testConcat2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(4, 8, 4)).reshape(1,4); INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(4, 8, 4)).reshape(1,4);
@ -1628,7 +1713,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTile2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTile2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1,4)); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1,4));
SDVariable x = sameDiff.var("x", arr); SDVariable x = sameDiff.var("x", arr);
@ -1641,7 +1728,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testSlice2d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlice2d(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1657,7 +1746,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testSlice3d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlice3d(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1672,7 +1763,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStridedSlice2dBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSlice2dBasic(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1690,7 +1783,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testStridedSliceBeginEndMask() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceBeginEndMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1705,7 +1800,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStridedSliceEllipsisMask() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceEllipsisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
@ -1722,7 +1819,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStridedSliceNewAxisMask() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceNewAxisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
@ -1735,7 +1834,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStridedSliceNewAxisMask2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceNewAxisMask2(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr); SDVariable in = sd.var("in", inArr);
@ -1746,7 +1847,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStridedSliceShrinkAxisMask() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceShrinkAxisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1763,7 +1866,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSizeAt_1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSizeAt_1(Nd4jBackend backend) {
val array = Nd4j.create(10, 20, 30); val array = Nd4j.create(10, 20, 30);
val exp = Nd4j.scalar(DataType.LONG, 20); val exp = Nd4j.scalar(DataType.LONG, 20);
@ -1777,6 +1882,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEye(){ public void testEye(){
int[] rows = new int[]{3,3,3,3}; int[] rows = new int[]{3,3,3,3};
int[] cols = new int[]{3,2,2,2}; int[] cols = new int[]{3,2,2,2};
@ -1815,6 +1922,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSplit1(){ public void testSplit1(){
INDArray in = Nd4j.linspace(1,10,10).reshape(10); INDArray in = Nd4j.linspace(1,10,10).reshape(10);
INDArray axis = Nd4j.scalar(-1); INDArray axis = Nd4j.scalar(-1);
@ -1833,6 +1942,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSplit2(){ public void testSplit2(){
INDArray in = Nd4j.linspace(1,24,24).reshape(3,8); INDArray in = Nd4j.linspace(1,24,24).reshape(3,8);
INDArray axis = Nd4j.scalar(-1); INDArray axis = Nd4j.scalar(-1);
@ -1851,6 +1962,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDistancesExec(){ public void testDistancesExec(){
//https://github.com/deeplearning4j/deeplearning4j/issues/7001 //https://github.com/deeplearning4j/deeplearning4j/issues/7001
for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) { for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) {
@ -1906,6 +2019,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionShape(){ public void testReductionShape(){
INDArray shape = Nd4j.createFromArray(4,2); INDArray shape = Nd4j.createFromArray(4,2);
@ -1924,6 +2039,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void gatherTest(){ public void gatherTest(){
INDArray in = Nd4j.createFromArray(new double[][]{ INDArray in = Nd4j.createFromArray(new double[][]{
{1,2,3,4,5}, {1,2,3,4,5},
@ -1943,6 +2060,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSliceShape(){ public void testSliceShape(){
INDArray arr = Nd4j.arange(0, 25).reshape(1,5,5).castTo(DataType.INT); INDArray arr = Nd4j.arange(0, 25).reshape(1,5,5).castTo(DataType.INT);
@ -1964,6 +2083,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testWhereAllFalse(){ public void testWhereAllFalse(){
INDArray in = Nd4j.create(DataType.BOOL, 1917); INDArray in = Nd4j.create(DataType.BOOL, 1917);
DynamicCustomOp op = DynamicCustomOp.builder("Where") DynamicCustomOp op = DynamicCustomOp.builder("Where")
@ -1978,6 +2099,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherScalar(){ public void testGatherScalar(){
INDArray in = Nd4j.linspace(100, 200, 100, DataType.FLOAT).reshape(100); INDArray in = Nd4j.linspace(100, 200, 100, DataType.FLOAT).reshape(100);
INDArray indices = Nd4j.scalar(0); INDArray indices = Nd4j.scalar(0);
@ -2002,6 +2125,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCastEmpty(){ public void testCastEmpty(){
INDArray emptyLong = Nd4j.empty(DataType.LONG); INDArray emptyLong = Nd4j.empty(DataType.LONG);
int dtype = 9; //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h int dtype = 9; //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h
@ -2018,6 +2143,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherEmpty(){ public void testGatherEmpty(){
/* /*
tf.reset_default_graph() tf.reset_default_graph()
@ -2050,6 +2177,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSplitEmpty(){ public void testSplitEmpty(){
/* /*
tf.reset_default_graph() tf.reset_default_graph()
@ -2087,6 +2216,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcatEmpty(){ public void testConcatEmpty(){
/* /*
TF behaviour with concatenatioun of empty arrays: TF behaviour with concatenatioun of empty arrays:
@ -2136,6 +2267,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcatEmpty2(){ public void testConcatEmpty2(){
INDArray empty10a = Nd4j.create(DataType.INT, 1, 0); INDArray empty10a = Nd4j.create(DataType.INT, 1, 0);
INDArray empty10b = Nd4j.create(DataType.INT, 1, 0); INDArray empty10b = Nd4j.create(DataType.INT, 1, 0);
@ -2168,6 +2301,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyGather(){ public void testEmptyGather(){
/* /*
tf.reset_default_graph() tf.reset_default_graph()
@ -2200,6 +2335,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastDynamicShape1(){ public void testBroadcastDynamicShape1(){
//Test case: [2,1] and [4]: expect [2,4] //Test case: [2,1] and [4]: expect [2,4]
@ -2221,6 +2358,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastDynamicShape2(){ public void testBroadcastDynamicShape2(){
//Test case: [2,1,4] and [2,2,4]: expect [2,2,4] //Test case: [2,1,4] and [2,2,4]: expect [2,2,4]
@ -2243,6 +2382,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceShrinkAxis(){ public void testStridedSliceShrinkAxis(){
INDArray in = Nd4j.create(DataType.DOUBLE, 3,2,2); INDArray in = Nd4j.create(DataType.DOUBLE, 3,2,2);
INDArray begin = Nd4j.createFromArray(2); INDArray begin = Nd4j.createFromArray(2);
@ -2268,6 +2409,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceEmpty(){ public void testStridedSliceEmpty(){
INDArray in = Nd4j.createFromArray(10); //Integer, Length 1, rank 1, value 10 - Not used due to begin mask! INDArray in = Nd4j.createFromArray(10); //Integer, Length 1, rank 1, value 10 - Not used due to begin mask!
@ -2290,6 +2433,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceEdgeCase(){ public void testStridedSliceEdgeCase(){
INDArray in = Nd4j.scalar(10).reshape(1); //Int [1] INDArray in = Nd4j.scalar(10).reshape(1); //Int [1]
INDArray begin = Nd4j.ones(DataType.INT, 1); INDArray begin = Nd4j.ones(DataType.INT, 1);
@ -2315,6 +2460,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptySlice1(){ public void testEmptySlice1(){
INDArray in = Nd4j.createFromArray(38); INDArray in = Nd4j.createFromArray(38);
INDArray begin = Nd4j.createFromArray(1); INDArray begin = Nd4j.createFromArray(1);
@ -2334,6 +2481,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptySlice2(){ public void testEmptySlice2(){
INDArray in = Nd4j.createFromArray(38); INDArray in = Nd4j.createFromArray(38);
INDArray begin = Nd4j.createFromArray(0); INDArray begin = Nd4j.createFromArray(0);
@ -2353,6 +2502,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFill(){ public void testFill(){
INDArray shape = Nd4j.createFromArray(0,4); INDArray shape = Nd4j.createFromArray(0,4);
@ -2372,6 +2523,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFill2(){ public void testFill2(){
INDArray shape = Nd4j.createFromArray(0,4); INDArray shape = Nd4j.createFromArray(0,4);
@ -2389,6 +2542,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermuteShapeDynamicAxis(){ public void testPermuteShapeDynamicAxis(){
DynamicCustomOp op = DynamicCustomOp.builder("permute") DynamicCustomOp op = DynamicCustomOp.builder("permute")
@ -2418,6 +2573,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGather2(){ public void testGather2(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable input = sd.var("in", Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3)); SDVariable input = sd.var("in", Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3));
@ -2437,6 +2594,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute3(){ public void testPermute3(){
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2); INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
INDArray permute = Nd4j.createFromArray(1,0); INDArray permute = Nd4j.createFromArray(1,0);
@ -2455,6 +2614,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute4(){ public void testPermute4(){
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2); INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
INDArray permute = Nd4j.createFromArray(1,0); INDArray permute = Nd4j.createFromArray(1,0);
@ -2485,6 +2646,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvertPermutation(){ public void testInvertPermutation(){
DynamicCustomOp op = DynamicCustomOp.builder("invert_permutation") DynamicCustomOp op = DynamicCustomOp.builder("invert_permutation")
.addInputs(Nd4j.createFromArray(1, 0)) .addInputs(Nd4j.createFromArray(1, 0))
@ -2492,7 +2655,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBroadcastInt1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastInt1(Nd4jBackend backend) {
INDArray out = Nd4j.create(DataType.INT, 1); INDArray out = Nd4j.create(DataType.INT, 1);
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape") DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
@ -2505,6 +2670,8 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastInt2(){ public void testBroadcastInt2(){
INDArray out = Nd4j.create(DataType.INT, 2); INDArray out = Nd4j.create(DataType.INT, 2);
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape") DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
@ -2544,7 +2711,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test @Test
public void testMergeMaxIndex() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeMaxIndex(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2561,7 +2730,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTriOp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTriOp(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable out = new Tri(sd, DataType.INT32, 3, 5, 2).outputVariable(); SDVariable out = new Tri(sd, DataType.INT32, 3, 5, 2).outputVariable();
@ -2573,7 +2744,9 @@ public class ShapeOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTriuOp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTriuOp(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable input = sd.var(Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {7,8,9},{10,11,12}})); SDVariable input = sd.var(Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {7,8,9},{10,11,12}}));
@ -2581,8 +2754,8 @@ public class ShapeOpValidation extends BaseOpValidation {
out.markAsLoss(); out.markAsLoss();
INDArray expected = Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {0,8,9},{0,0,12}}); INDArray expected = Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {0,8,9},{0,0,12}});
String err = OpValidation.validate(new TestCase(sd) String err = OpValidation.validate(new TestCase(sd)
.expectedOutput("triu", expected) .expectedOutput("triu", expected)
.gradientCheck(true)); .gradientCheck(true));
assertNull(err); assertNull(err);
} }

View File

@ -26,6 +26,8 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -94,9 +96,6 @@ public class TransformOpValidation extends BaseOpValidation {
private DataType initialType; private DataType initialType;
public TransformOpValidation(Nd4jBackend backend) {
super(backend);
}
@BeforeEach @BeforeEach
public void before() { public void before() {
@ -120,7 +119,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testScalarOps() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarOps(Nd4jBackend backend) {
int d0 = 2; int d0 = 2;
int d1 = 3; int d1 = 3;
int d2 = 4; int d2 = 4;
@ -217,7 +218,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testScalarMulCF() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarMulCF(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
INDArray outC = Nd4j.createUninitialized(3, 4); INDArray outC = Nd4j.createUninitialized(3, 4);
@ -231,7 +234,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test @Test
public void testScalarMulCF2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarMulCF2(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
@ -242,7 +247,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testCross() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCross(Nd4jBackend backend) {
INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3}); INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3});
INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3}); INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3});
@ -270,7 +277,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSpaceToDepth() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpaceToDepth(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);
int miniBatch = 128; int miniBatch = 128;
@ -298,7 +307,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDepthToSpace() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepthToSpace(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);
int miniBatch = 128; int miniBatch = 128;
@ -325,7 +336,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBatchToSpace() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBatchToSpace(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863 //OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);
@ -362,7 +375,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSpaceToBatch() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpaceToBatch(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863 //OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
Nd4j.getRandom().setSeed(7331); Nd4j.getRandom().setSeed(7331);
@ -400,7 +415,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDynamicPartition() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicPartition(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new double[]{4, 3, 5, 7, 8, 0}); INDArray ia = Nd4j.create(new double[]{4, 3, 5, 7, 8, 0});
@ -440,7 +457,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDynamicPartition2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicPartition2(Nd4jBackend backend) {
INDArray data = Nd4j.createFromArray(2, 1, 2, 0); INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition") INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
@ -458,7 +477,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDynamicStitch() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicStitch(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new double[]{5, 1, 3}, new long[]{3}); INDArray ia = Nd4j.create(new double[]{5, 1, 3}, new long[]{3});
@ -495,7 +516,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDiag() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiag(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new double[]{1, 2}, new int[]{2}); INDArray ia = Nd4j.create(new double[]{1, 2}, new int[]{2});
@ -521,7 +544,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testDiagPart() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagPart(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray input = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4); INDArray input = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
@ -540,7 +565,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testEye() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEye(Nd4jBackend backend) {
int[] rows = new int[]{3, 3, 3, 3}; int[] rows = new int[]{3, 3, 3, 3};
int[] cols = new int[]{3, 2, 2, 2}; int[] cols = new int[]{3, 2, 2, 2};
int[][] batch = new int[][]{{}, {}, {4}, {3, 3}}; int[][] batch = new int[][]{{}, {}, {4}, {3, 3}};
@ -574,7 +601,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testEyeShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEyeShape(Nd4jBackend backend) {
DynamicCustomOp dco = DynamicCustomOp.builder("eye") DynamicCustomOp dco = DynamicCustomOp.builder("eye")
.addIntegerArguments(3, 3) .addIntegerArguments(3, 3)
//.addIntegerArguments(-99,3,3) //Also fails //.addIntegerArguments(-99,3,3) //Also fails
@ -586,7 +615,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTransforms() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTransforms(Nd4jBackend backend) {
//Test transforms (non-pairwise) //Test transforms (non-pairwise)
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1074,7 +1105,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test @Test
public void testPairwiseTransforms() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPairwiseTransforms(Nd4jBackend backend) {
/* /*
add, sub, mul, div, rsub, rdiv add, sub, mul, div, rsub, rdiv
eq, neq, gt, lt, gte, lte, or, and, xor eq, neq, gt, lt, gte, lte, or, and, xor
@ -1258,7 +1291,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testIsX() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIsX(Nd4jBackend backend) {
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
@ -1313,7 +1348,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReplaceWhereScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReplaceWhereScalar(Nd4jBackend backend) {
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) { for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
log.info("Testing condition: " + c.getClass().getSimpleName()); log.info("Testing condition: " + c.getClass().getSimpleName());
@ -1335,7 +1372,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReplaceWhereArray() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReplaceWhereArray(Nd4jBackend backend) {
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) { for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
INDArray inArr = Nd4j.rand(3, 4); INDArray inArr = Nd4j.rand(3, 4);
@ -1358,7 +1397,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLogGrad() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogGrad(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
SDVariable input = sameDiff.var("x", Nd4j.linspace(1, 4, 4, DataType.DOUBLE)); SDVariable input = sameDiff.var("x", Nd4j.linspace(1, 4, 4, DataType.DOUBLE));
SDVariable log = sameDiff.math().log(input); SDVariable log = sameDiff.math().log(input);
@ -1369,7 +1410,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test @Test
public void testSigmoidBackwards() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSigmoidBackwards(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
Map<String, INDArray> inputs = new HashMap<>(); Map<String, INDArray> inputs = new HashMap<>();
@ -1386,8 +1429,10 @@ public class TransformOpValidation extends BaseOpValidation {
} }
/* @Test /* @Test
public void testDepth() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepth(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
SDVariable x = sameDiff.one("one",new long[]{2,2}); SDVariable x = sameDiff.one("one",new long[]{2,2});
assertEquals(0,x.depth()); assertEquals(0,x.depth());
@ -1396,7 +1441,9 @@ public class TransformOpValidation extends BaseOpValidation {
}*/ }*/
@Test @Test
public void testRank0EdgeCase() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRank0EdgeCase(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))); SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4})));
double d0 = v1.eval().getDouble(0); double d0 = v1.eval().getDouble(0);
@ -1409,7 +1456,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testAtan2BroadcastShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAtan2BroadcastShape(Nd4jBackend backend) {
INDArray arr1 = Nd4j.create(new long[]{3, 1, 4}); INDArray arr1 = Nd4j.create(new long[]{3, 1, 4});
INDArray arr2 = Nd4j.create(new long[]{1, 2, 4}); INDArray arr2 = Nd4j.create(new long[]{1, 2, 4});
@ -1424,7 +1473,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBooleanAnd() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBooleanAnd(Nd4jBackend backend) {
Nd4j.setDataType(DataType.FLOAT); Nd4j.setDataType(DataType.FLOAT);
INDArray arr1 = Nd4j.create(new long[]{3, 4}); INDArray arr1 = Nd4j.create(new long[]{3, 4});
INDArray arr2 = Nd4j.create(new long[]{3, 4}); INDArray arr2 = Nd4j.create(new long[]{3, 4});
@ -1438,7 +1489,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testScatterOpsScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScatterOpsScalar(Nd4jBackend backend) {
for (String s : new String[]{"add", "sub", "mul", "div"}) { for (String s : new String[]{"add", "sub", "mul", "div"}) {
INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3); INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
INDArray indices = Nd4j.scalar(5); INDArray indices = Nd4j.scalar(5);
@ -1483,7 +1536,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Disabled("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540") @Disabled("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540")
@Test @Test
public void testPad() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPad(Nd4jBackend backend) {
INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0); INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG); INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG);
INDArray value = Nd4j.scalar(10.0); INDArray value = Nd4j.scalar(10.0);
@ -1510,7 +1565,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test @Test
public void testMirrorPad() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMirrorPad(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT); INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
@ -1543,7 +1600,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMirrorPad2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMirrorPad2(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT); INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
@ -1569,7 +1628,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMirrorPadSymmetric() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMirrorPadSymmetric(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4); INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT); INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT);
@ -1596,7 +1657,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testUnique() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnique(Nd4jBackend backend) {
INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4}); INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4});
INDArray expUnique = Nd4j.create(new double[]{3, 4, 1, 0, 2}); INDArray expUnique = Nd4j.create(new double[]{3, 4, 1, 0, 2});
@ -1618,7 +1681,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTopK() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopK(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //Can't assume sorted here OpValidationSuite.ignoreFailing(); //Can't assume sorted here
INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8}); INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8});
@ -1647,7 +1712,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testTopK1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopK1(Nd4jBackend backend) {
INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0); INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
INDArray k = Nd4j.scalar(1); INDArray k = Nd4j.scalar(1);
INDArray outValue = Nd4j.create(DataType.DOUBLE, 1); INDArray outValue = Nd4j.create(DataType.DOUBLE, 1);
@ -1668,7 +1735,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testInTopK() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInTopK(Nd4jBackend backend) {
for (int k = 4; k >= 1; k--) { for (int k = 4; k >= 1; k--) {
log.info("Testing: k=" + k); log.info("Testing: k=" + k);
INDArray in = Nd4j.linspace(1, 20, 20, DataType.DOUBLE).reshape(4, 5); INDArray in = Nd4j.linspace(1, 20, 20, DataType.DOUBLE).reshape(4, 5);
@ -1709,7 +1778,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testZeta() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZeta(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182 OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182
INDArray x = Nd4j.rand(3, 4).addi(1.0); INDArray x = Nd4j.rand(3, 4).addi(1.0);
INDArray q = Nd4j.rand(3, 4); INDArray q = Nd4j.rand(3, 4);
@ -1726,7 +1797,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMaxEmptyScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxEmptyScalar(Nd4jBackend backend) {
INDArray empty = Nd4j.empty(DataType.FLOAT); INDArray empty = Nd4j.empty(DataType.FLOAT);
INDArray scalar = Nd4j.scalar(1.0f); INDArray scalar = Nd4j.scalar(1.0f);
@ -1743,7 +1816,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testBroadcastEmpty() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastEmpty(Nd4jBackend backend) {
// Nd4j.getExecutioner().enableVerboseMode(true); // Nd4j.getExecutioner().enableVerboseMode(true);
// Nd4j.getExecutioner().enableDebugMode(true); // Nd4j.getExecutioner().enableDebugMode(true);
//Check broadcast behaviour with empty arrays. The idea is to match TF import behaviour, for import //Check broadcast behaviour with empty arrays. The idea is to match TF import behaviour, for import
@ -1833,7 +1908,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStandardize() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardize(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4}); final INDArray random = Nd4j.rand(new int[]{10, 4});
final int[] axis = new int[]{1}; final int[] axis = new int[]{1};
@ -1854,7 +1931,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStandardizeOP() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardizeOP(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4}); final INDArray random = Nd4j.rand(new int[]{10, 4});
final int[] axis = new int[]{1}; final int[] axis = new int[]{1};
@ -1869,7 +1948,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testStandardizeNoDeviation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardizeNoDeviation(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4}); final INDArray random = Nd4j.rand(new int[]{10, 4});
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
random.putScalar(1, i, 7); random.putScalar(1, i, 7);
@ -1895,7 +1976,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMatMulTensor() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatMulTensor(Nd4jBackend backend) {
final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5}); final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5});
final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6}); final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6});
@ -1915,7 +1998,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMatMulTensorTranspose() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatMulTensorTranspose(Nd4jBackend backend) {
for (boolean transposeA : new boolean[]{false, true}) { for (boolean transposeA : new boolean[]{false, true}) {
for (boolean transposeB : new boolean[]{false, true}) { for (boolean transposeB : new boolean[]{false, true}) {
for (boolean transposeResult : new boolean[]{false, true}) { for (boolean transposeResult : new boolean[]{false, true}) {
@ -2008,7 +2093,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testSoftmaxCF() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSoftmaxCF(Nd4jBackend backend) {
INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5); INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5);
INDArray arrF = arrC.dup('f'); INDArray arrF = arrC.dup('f');
@ -2029,7 +2116,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLogSumExp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogSumExp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4); INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2044,7 +2133,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testLogSumExp2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogSumExp2(Nd4jBackend backend) {
for (int dim = 0; dim <= 2; dim++) { for (int dim = 0; dim <= 2; dim++) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -2065,7 +2156,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test @Test
public void testCRELU() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCRELU(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2); INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2);
@ -2084,7 +2177,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testClipByAvgNorm() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByAvgNorm(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2, 2); INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2, 2);
@ -2105,7 +2200,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testEmbeddingLookup() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmbeddingLookup(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2118,49 +2215,53 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testImageResize() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testImageResize(Nd4jBackend backend) {
//TODO: Methods failed ResizeLanczos5, ResizeMitchelcubic, ResizeArea //TODO: Methods failed ResizeLanczos5, ResizeMitchelcubic, ResizeArea
for (ImageResizeMethod method : ImageResizeMethod.values()) { for (ImageResizeMethod method : ImageResizeMethod.values()) {
if (method==ImageResizeMethod.ResizeLanczos5 || method==ImageResizeMethod.ResizeArea || method == ImageResizeMethod.ResizeMitchelcubic) if (method==ImageResizeMethod.ResizeLanczos5 || method==ImageResizeMethod.ResizeArea || method == ImageResizeMethod.ResizeMitchelcubic)
{continue;} {continue;}
log.info("Trying {}", method); log.info("Trying {}", method);
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
boolean preserveAspectRatio = true; boolean preserveAspectRatio = true;
boolean antialias = true; boolean antialias = true;
SDVariable inputImage = sd.var(Nd4j.rand(DataType.FLOAT, 1, 5, 5, 3)); SDVariable inputImage = sd.var(Nd4j.rand(DataType.FLOAT, 1, 5, 5, 3));
// NHWC format // NHWC format
long[] expectedShape = new long[]{1, 3, 3, 3}; long[] expectedShape = new long[]{1, 3, 3, 3};
SDVariable requestedSize = sd.constant(Nd4j.createFromArray( new long[]{3, 3})); SDVariable requestedSize = sd.constant(Nd4j.createFromArray( new long[]{3, 3}));
Function<INDArray, String> checkFunction = in -> { Function<INDArray, String> checkFunction = in -> {
boolean shapeOk = Arrays.equals(expectedShape, in.shape()); boolean shapeOk = Arrays.equals(expectedShape, in.shape());
if (shapeOk) return null; if (shapeOk) return null;
return "Failed: shape differs - expected " + Arrays.toString(expectedShape) + " vs " + Arrays.toString(in.shape()) + " on method " + method; return "Failed: shape differs - expected " + Arrays.toString(expectedShape) + " vs " + Arrays.toString(in.shape()) + " on method " + method;
}; };
SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable().std(true); SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable().std(true);
String err = OpValidation.validate(new TestCase(sd) String err = OpValidation.validate(new TestCase(sd)
.gradientCheck(false) .gradientCheck(false)
.expected("image_resize", checkFunction)); .expected("image_resize", checkFunction));
assertNull(err); assertNull(err);
} }
} }
@Test @Test
public void testMaximumBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaximumBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2177,7 +2278,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMergeAddBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeAddBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2194,7 +2297,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testMergeMaxBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeMaxBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2212,7 +2317,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test @Test
public void testMergeAvgBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeAvgBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2229,7 +2336,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testReverseBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReverseBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -2243,7 +2352,9 @@ public class TransformOpValidation extends BaseOpValidation {
} }
@Test @Test
public void testUpsampling3dBp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUpsampling3dBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
for (boolean dataformat : new boolean[]{true, false}) { for (boolean dataformat : new boolean[]{true, false}) {

View File

@ -24,8 +24,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
@ -36,11 +37,8 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
public class ConvConfigTests extends BaseNd4jTest { public class ConvConfigTests extends BaseNd4jTestWithBackends {
public ConvConfigTests(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -48,7 +46,9 @@ public class ConvConfigTests extends BaseNd4jTest {
} }
@Test @Test
public void testDeConv2D(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeConv2D(Nd4jBackend backend){
DeConv2DConfig.builder().kH(2).kW(4).build(); DeConv2DConfig.builder().kH(2).kW(4).build();
try{ try{
@ -108,8 +108,10 @@ public class ConvConfigTests extends BaseNd4jTest {
} }
} }
@Test @Test
public void testConv2D(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv2D(Nd4jBackend backend){
Conv2DConfig.builder().kH(2).kW(4).build(); Conv2DConfig.builder().kH(2).kW(4).build();
try{ try{
@ -169,8 +171,10 @@ public class ConvConfigTests extends BaseNd4jTest {
} }
} }
@Test @Test
public void testPooling2D(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPooling2D(Nd4jBackend backend){
Pooling2DConfig.builder().kH(2).kW(4).build(); Pooling2DConfig.builder().kH(2).kW(4).build();
try{ try{
@ -230,8 +234,10 @@ public class ConvConfigTests extends BaseNd4jTest {
} }
} }
@Test @Test
public void testDeConv3D(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeConv3D(Nd4jBackend backend){
DeConv3DConfig.builder().kH(2).kW(4).kD(3).build(); DeConv3DConfig.builder().kH(2).kW(4).kD(3).build();
try{ try{
@ -319,8 +325,10 @@ public class ConvConfigTests extends BaseNd4jTest {
} }
} }
@Test @Test
public void testConv3D(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv3D(Nd4jBackend backend){
Conv3DConfig.builder().kH(2).kW(4).kD(3).build(); Conv3DConfig.builder().kH(2).kW(4).kD(3).build();
try{ try{
@ -410,8 +418,10 @@ public class ConvConfigTests extends BaseNd4jTest {
@Test @Test
public void testPooling3D(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPooling3D(Nd4jBackend backend){
Pooling3DConfig.builder().kH(2).kW(4).kD(3).build(); Pooling3DConfig.builder().kH(2).kW(4).kD(3).build();
try{ try{
@ -499,7 +509,9 @@ public class ConvConfigTests extends BaseNd4jTest {
} }
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1D(){ public void testConv1D(){
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build(); Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();

View File

@ -23,8 +23,10 @@ package org.nd4j.autodiff.samediff;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -40,11 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Disabled("AB 2019/05/21 - JVM Crash on ppc64 - Issue #7657") @Disabled("AB 2019/05/21 - JVM Crash on ppc64 - Issue #7657")
public class FailingSameDiffTests extends BaseNd4jTest { public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
public FailingSameDiffTests(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -52,7 +51,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
} }
@Test @Test
public void testEye(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEye(Nd4jBackend backend){
//OpValidationSuite.ignoreFailing(); //OpValidationSuite.ignoreFailing();
INDArray arr = Nd4j.create(new double[]{1, 0, 0, 0, 1, 0}, new int[]{2, 3}); INDArray arr = Nd4j.create(new double[]{1, 0, 0, 0, 1, 0}, new int[]{2, 3});
List<INDArray> stack = new ArrayList<>(); List<INDArray> stack = new ArrayList<>();
@ -68,7 +69,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
} }
@Test @Test
public void testEyeShape(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEyeShape(Nd4jBackend backend){
val dco = DynamicCustomOp.builder("eye") val dco = DynamicCustomOp.builder("eye")
.addIntegerArguments(3,3) .addIntegerArguments(3,3)
//.addIntegerArguments(-99,3,3) //Also fails //.addIntegerArguments(-99,3,3) //Also fails
@ -80,7 +83,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
} }
@Test @Test
public void testExecutionDifferentShapesTransform(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExecutionDifferentShapesTransform(Nd4jBackend backend){
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4)); SDVariable in = sd.var("in", Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4));
@ -101,7 +106,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
} }
@Test @Test
public void testDropout() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDropout(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
double p = 0.5; double p = 0.5;
@ -114,7 +121,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
} }
@Test @Test
public void testExecutionDifferentShapesDynamicCustom(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExecutionDifferentShapesDynamicCustom(Nd4jBackend backend){
OpValidationSuite.ignoreFailing(); OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();

View File

@ -26,13 +26,15 @@ import org.apache.commons.io.IOUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.graph.FlatConfiguration; import org.nd4j.graph.FlatConfiguration;
import org.nd4j.graph.FlatGraph; import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode; import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatVariable; import org.nd4j.graph.FlatVariable;
import org.nd4j.graph.IntPair; import org.nd4j.graph.IntPair;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
@ -70,11 +72,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j @Slf4j
public class FlatBufferSerdeTest extends BaseNd4jTest { public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
public FlatBufferSerdeTest(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -84,7 +83,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
@Test @Test
public void testBasic(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4); INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() ); SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() );
@ -139,7 +140,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
} }
@Test @Test
public void testSimple(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
for( int i = 0; i < 10; i++ ) { for( int i = 0; i < 10; i++ ) {
for(boolean execFirst : new boolean[]{false, true}) { for(boolean execFirst : new boolean[]{false, true}) {
log.info("Starting test: i={}, execFirst={}", i, execFirst); log.info("Starting test: i={}, execFirst={}", i, execFirst);
@ -268,7 +271,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
@Test @Test
public void testTrainingSerde(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
//Ensure 2 things: //Ensure 2 things:
//1. Training config is serialized/deserialized correctly //1. Training config is serialized/deserialized correctly
@ -352,7 +357,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
@Test @Test
public void pooling3DSerialization(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void pooling3DSerialization(Nd4jBackend backend){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28); SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28);
@ -372,7 +379,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
} }
@Test @Test
public void pooling3DSerialization2(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void pooling3DSerialization2(Nd4jBackend backend){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28); SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28);

View File

@ -22,12 +22,14 @@ package org.nd4j.autodiff.samediff;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.transform.GraphTransformUtil; import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
import org.nd4j.autodiff.samediff.transform.OpPredicate; import org.nd4j.autodiff.samediff.transform.OpPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraph; import org.nd4j.autodiff.samediff.transform.SubGraph;
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate; import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor; import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
@ -42,11 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j @Slf4j
public class GraphTransformUtilTests extends BaseNd4jTest { public class GraphTransformUtilTests extends BaseNd4jTestWithBackends {
public GraphTransformUtilTests(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -54,7 +53,9 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
} }
@Test @Test
public void testBasic(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBasic(Nd4jBackend backend){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 32); SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 32);
@ -93,7 +94,9 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
} }
@Test @Test
public void testSubgraphReplace1(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSubgraphReplace1(Nd4jBackend backend){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 4); SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 4);

View File

@ -21,8 +21,10 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr; import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -32,11 +34,8 @@ import java.lang.reflect.Field;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class MemoryMgrTest extends BaseNd4jTest { public class MemoryMgrTest extends BaseNd4jTestWithBackends {
public MemoryMgrTest(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -44,7 +43,9 @@ public class MemoryMgrTest extends BaseNd4jTest {
} }
@Test @Test
public void testArrayReuseTooLarge() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testArrayReuseTooLarge(Nd4jBackend backend) throws Exception {
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr(); ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
Field f = ArrayCacheMemoryMgr.class.getDeclaredField("maxCacheBytes"); Field f = ArrayCacheMemoryMgr.class.getDeclaredField("maxCacheBytes");
@ -97,7 +98,7 @@ public class MemoryMgrTest extends BaseNd4jTest {
assertEquals(10, mmgr.getLruCacheValues().size()); assertEquals(10, mmgr.getLruCacheValues().size());
//now, allocate some values: //now, allocate some values:
for( int i=1; i<=10; i++ ) { for( int i = 1; i <= 10; i++) {
INDArray a1 = mmgr.allocate(true, DataType.FLOAT, 25); INDArray a1 = mmgr.allocate(true, DataType.FLOAT, 25);
assertEquals(1000 - i * 100, mmgr.getCurrentCacheSize()); assertEquals(1000 - i * 100, mmgr.getCurrentCacheSize());
assertEquals(1000 - i * 100, as.getBytesSum()); assertEquals(1000 - i * 100, as.getBytesSum());
@ -116,10 +117,12 @@ public class MemoryMgrTest extends BaseNd4jTest {
} }
@Test @Test
public void testManyArrays(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testManyArrays(Nd4jBackend backend){
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr(); ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
for( int i=0; i<1000; i++ ){ for( int i = 0; i < 1000; i++) {
mmgr.release(Nd4j.scalar(0)); mmgr.release(Nd4j.scalar(0));
} }
@ -127,7 +130,7 @@ public class MemoryMgrTest extends BaseNd4jTest {
assertEquals(1000, mmgr.getLruCache().size()); assertEquals(1000, mmgr.getLruCache().size());
assertEquals(1000, mmgr.getLruCacheValues().size()); assertEquals(1000, mmgr.getLruCacheValues().size());
for( int i=0; i<1000; i++ ){ for( int i = 0; i < 1000; i++ ){
mmgr.release(Nd4j.scalar(0)); mmgr.release(Nd4j.scalar(0));
} }

View File

@ -21,9 +21,11 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -35,19 +37,18 @@ import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class NameScopeTests extends BaseNd4jTest { public class NameScopeTests extends BaseNd4jTestWithBackends {
public NameScopeTests(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering() {
return 'c'; return 'c';
} }
@Test @Test
public void testVariableNameScopesBasic(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVariableNameScopesBasic(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable v = sd.var("x"); SDVariable v = sd.var("x");
@ -73,7 +74,9 @@ public class NameScopeTests extends BaseNd4jTest {
} }
@Test @Test
public void testOpFieldsAndNames(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOpFieldsAndNames(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable x = sd.var("x", DataType.FLOAT, 1); SDVariable x = sd.var("x", DataType.FLOAT, 1);
@ -151,7 +154,9 @@ public class NameScopeTests extends BaseNd4jTest {
} }
@Test @Test
public void testNoNesting(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNoNesting(Nd4jBackend backend) {
SameDiff SD = SameDiff.create(); SameDiff SD = SameDiff.create();
SDVariable a = SD.constant(4); SDVariable a = SD.constant(4);
@ -168,7 +173,9 @@ public class NameScopeTests extends BaseNd4jTest {
} }
@Test @Test
public void testNoTesting2(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNoTesting2(Nd4jBackend backend) {
SameDiff SD = SameDiff.create(); SameDiff SD = SameDiff.create();
SDVariable a = SD.constant(4); SDVariable a = SD.constant(4);

View File

@ -21,21 +21,16 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.imports.tfgraphs.TFGraphTestZooModels;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.AtomicBoolean; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.resources.Resources;
import java.io.File;
import java.nio.file.Path;
import java.util.Collections; import java.util.Collections;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore; import java.util.concurrent.Semaphore;
@ -55,7 +50,9 @@ public class SameDiffMultiThreadTests extends BaseND4JTest {
} }
@Test @Test
public void testSimple() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(Nd4jBackend backend) throws Exception {
int nThreads = 4; int nThreads = 4;
int nRuns = 1000; int nRuns = 1000;
@ -103,48 +100,6 @@ public class SameDiffMultiThreadTests extends BaseND4JTest {
} }
} }
@Test
@Disabled //2020/03/24 AB - https://github.com/eclipse/deeplearning4j/issues/8802
public void testMobilenet(@TempDir Path testDir) throws Exception {
TFGraphTestZooModels.currentTestDir = testDir.toFile();
File f = Resources.asFile("tf_graphs/zoo_models/mobilenet_v2_1.0_224/tf_model.txt");
SameDiff sd = TFGraphTestZooModels.LOADER.apply(f, "mobilenet_v2_1.0_224");
// System.out.println(sd.summary());
int nThreads = 4;
int nRuns = 30;
INDArray[] inputArrs = new INDArray[nThreads];
INDArray[] expOut = new INDArray[nThreads];
for( int i=0; i<nThreads; i++ ){
if(i == 0 || i > 2)
inputArrs[i] = Nd4j.rand(DataType.FLOAT, 1, 224, 224, 3);
else if(i == 1)
inputArrs[i] = Nd4j.zeros(DataType.FLOAT, 1, 224, 224, 3);
else if(i == 2)
inputArrs[i] = Nd4j.ones(DataType.FLOAT, 1, 224, 224, 3);
expOut[i] = sd.outputSingle(Collections.singletonMap("input", inputArrs[i]), "MobilenetV2/Predictions/Reshape_1");
Nd4j.getExecutioner().commit();
}
AtomicBoolean[] failuresByThread = new AtomicBoolean[nThreads];
AtomicInteger[] counters = new AtomicInteger[nThreads];
Semaphore s = new Semaphore(nThreads);
CountDownLatch latch = new CountDownLatch(nThreads);
doTest(sd, nThreads, nRuns, inputArrs, expOut, "input", "MobilenetV2/Predictions/Reshape_1", failuresByThread, counters, s, latch);
s.release(nThreads);
latch.await();
for(int i = 0; i < nThreads; i++) {
assertFalse( failuresByThread[i].get(),"Thread " + i + " failed");
}
for(int i = 0; i < nThreads; i++) {
assertEquals( nRuns, counters[i].get(),"Thread " + i + " number of runs");
}
}
public static void doTest(SameDiff sd, int nThreads, int nRuns, INDArray[] inputArrs, INDArray[] expOut, public static void doTest(SameDiff sd, int nThreads, int nRuns, INDArray[] inputArrs, INDArray[] expOut,

View File

@ -21,7 +21,9 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -31,14 +33,13 @@ import org.nd4j.linalg.learning.config.Sgd;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class SameDiffOutputTest extends BaseNd4jTest { public class SameDiffOutputTest extends BaseNd4jTestWithBackends {
public SameDiffOutputTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void outputTest(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void outputTest(Nd4jBackend backend){
DataSet data = new DataSet(Nd4j.zeros(10, 10), Nd4j.zeros(10, 10)); DataSet data = new DataSet(Nd4j.zeros(10, 10), Nd4j.zeros(10, 10));
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();

View File

@ -21,7 +21,9 @@
package org.nd4j.autodiff.samediff; package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -35,19 +37,18 @@ import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.assertNull; import static junit.framework.TestCase.assertNull;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest { public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
public SameDiffSpecifiedLossVarsTests(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering() {
return 'c'; return 'c';
} }
@Test @Test
public void testSpecifiedLoss1(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpecifiedLoss1(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable ph1 = sd.var("ph", DataType.FLOAT, 3, 4); SDVariable ph1 = sd.var("ph", DataType.FLOAT, 3, 4);
ph1.setArray(Nd4j.create(DataType.FLOAT, 3, 4)); ph1.setArray(Nd4j.create(DataType.FLOAT, 3, 4));
@ -68,7 +69,9 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
} }
@Test @Test
public void testSpecifiedLoss2(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpecifiedLoss2(Nd4jBackend backend) {
for( int i=0; i<2; i++ ) { for( int i=0; i<2; i++ ) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable ph = sd.placeHolder("ph", DataType.FLOAT, 3, 4); SDVariable ph = sd.placeHolder("ph", DataType.FLOAT, 3, 4);
@ -121,7 +124,9 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
@Test @Test
public void testTrainingDifferentLosses(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingDifferentLosses(Nd4jBackend backend) {
//Net with 2 losses: train on the first one, then change losses //Net with 2 losses: train on the first one, then change losses
//Also check that if modifying via add/setLossVariables the training config changes //Also check that if modifying via add/setLossVariables the training config changes

View File

@ -30,11 +30,13 @@ import java.util.Map;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.impl.ScoreListener; import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.listeners.records.History; import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -55,14 +57,13 @@ import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.weightinit.impl.XavierInitScheme; import org.nd4j.weightinit.impl.XavierInitScheme;
@Slf4j @Slf4j
public class SameDiffTrainingTest extends BaseNd4jTest { public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
public SameDiffTrainingTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void irisTrainingSanityCheck() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisTrainingSanityCheck(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize(); NormalizerStandardize std = new NormalizerStandardize();
@ -134,7 +135,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
@Test @Test
public void irisTrainingEvalTest() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisTrainingEvalTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize(); NormalizerStandardize std = new NormalizerStandardize();
@ -184,7 +187,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
@Test @Test
public void irisTrainingValidationTest() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisTrainingValidationTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize(); NormalizerStandardize std = new NormalizerStandardize();
@ -239,6 +244,8 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingMixedDtypes(){ public void testTrainingMixedDtypes(){
for (String u : new String[]{"adam", "nesterov", "adamax", "amsgrad"}) { for (String u : new String[]{"adam", "nesterov", "adamax", "amsgrad"}) {
@ -301,7 +308,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
} }
@Test @Test
public void simpleClassification() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void simpleClassification(Nd4jBackend backend) {
double learning_rate = 0.001; double learning_rate = 0.001;
int seed = 7; int seed = 7;
org.nd4j.linalg.api.rng.Random rng = Nd4j.getRandom(); org.nd4j.linalg.api.rng.Random rng = Nd4j.getRandom();
@ -348,6 +357,8 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingEvalVarNotReqForLoss(){ public void testTrainingEvalVarNotReqForLoss(){
//If a variable is not required for the loss - normally it won't be calculated //If a variable is not required for the loss - normally it won't be calculated
//But we want to make sure it IS calculated here - so we can perform evaluation on it //But we want to make sure it IS calculated here - so we can perform evaluation on it

View File

@ -25,11 +25,13 @@ import org.junit.Assert;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener; import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig; import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.IrisDataSetIterator; import org.nd4j.linalg.dataset.IrisDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -48,11 +50,8 @@ import java.util.concurrent.TimeUnit;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class CheckpointListenerTest extends BaseNd4jTest { public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
public CheckpointListenerTest(Nd4jBackend backend){
super(backend);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -96,7 +95,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
@Test @Test
public void testCheckpointEveryEpoch(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();
@ -130,7 +131,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
} }
@Test @Test
public void testCheckpointEvery5Iter(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();
@ -169,7 +172,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
@Test @Test
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();
@ -199,7 +204,7 @@ public class CheckpointListenerTest extends BaseNd4jTest {
for(File f : files){ for(File f : files){
String s = f.getAbsolutePath(); String s = f.getAbsolutePath();
// System.out.println(s); // System.out.println(s);
for( int i=0; i<names.size(); i++ ){ for( int i = 0; i < names.size(); i++ ){
if(s.contains(names.get(i))){ if(s.contains(names.get(i))){
found[i] = true; found[i] = true;
break; break;
@ -213,7 +218,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
} }
@Test @Test
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();

View File

@ -21,12 +21,13 @@
package org.nd4j.autodiff.samediff.listeners; package org.nd4j.autodiff.samediff.listeners;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener; import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig; import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -34,14 +35,13 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
public class ExecDebuggingListenerTest extends BaseNd4jTest { public class ExecDebuggingListenerTest extends BaseNd4jTestWithBackends {
public ExecDebuggingListenerTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testExecDebugListener(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExecDebugListener(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);

View File

@ -21,6 +21,8 @@
package org.nd4j.autodiff.samediff.listeners; package org.nd4j.autodiff.samediff.listeners;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener; import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Listener; import org.nd4j.autodiff.listeners.Listener;
@ -38,7 +40,7 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.Evaluation.Metric; import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.OpContext;
@ -61,11 +63,8 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class ListenerTest extends BaseNd4jTest { public class ListenerTest extends BaseNd4jTestWithBackends {
public ListenerTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -73,7 +72,9 @@ public class ListenerTest extends BaseNd4jTest {
} }
@Test @Test
public void irisHistoryTest() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisHistoryTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize(); NormalizerStandardize std = new NormalizerStandardize();
@ -136,6 +137,8 @@ public class ListenerTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testListenerCalls(){ public void testListenerCalls(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
@ -273,7 +276,9 @@ public class ListenerTest extends BaseNd4jTest {
} }
@Test @Test
public void testCustomListener() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCustomListener(Nd4jBackend backend) {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4); SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3); SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);

View File

@ -26,11 +26,13 @@ import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.profiler.ProfilingListener; import org.nd4j.autodiff.listeners.profiler.ProfilingListener;
import org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer; import org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -45,11 +47,8 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
public class ProfilingListenerTest extends BaseNd4jTest { public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
public ProfilingListenerTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -59,7 +58,9 @@ public class ProfilingListenerTest extends BaseNd4jTest {
@Test @Test
public void testProfilingListenerSimple(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
@ -107,19 +108,25 @@ public class ProfilingListenerTest extends BaseNd4jTest {
} }
/* /*
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoadTfProfile(){ public void testLoadTfProfile(){
File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json"); File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json");
ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoadTfProfileDir(){ public void testLoadTfProfileDir(){
File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles"); File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles");
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoadTfProfileDir2(){ public void testLoadTfProfileDir2(){
File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0"); File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0");
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW); ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);

View File

@ -27,6 +27,8 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.VariableType;
@ -41,7 +43,7 @@ import org.nd4j.graph.UIInfoType;
import org.nd4j.graph.UIOp; import org.nd4j.graph.UIOp;
import org.nd4j.graph.UIVariable; import org.nd4j.graph.UIVariable;
import org.nd4j.graph.ui.LogFileWriter; import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -60,11 +62,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j @Slf4j
public class FileReadWriteTests extends BaseNd4jTest { public class FileReadWriteTests extends BaseNd4jTestWithBackends {
public FileReadWriteTests(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -81,7 +80,9 @@ public class FileReadWriteTests extends BaseNd4jTest {
} }
@Test @Test
public void testSimple(@TempDir Path testDir) throws IOException { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4); SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
SDVariable sum = v.sum(); SDVariable sum = v.sum();
@ -185,7 +186,9 @@ public class FileReadWriteTests extends BaseNd4jTest {
} }
@Test @Test
public void testNullBinLabels(@TempDir Path testDir) throws Exception{ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{
File dir = testDir.toFile(); File dir = testDir.toFile();
File f = new File(dir, "temp.bin"); File f = new File(dir, "temp.bin");
LogFileWriter w = new LogFileWriter(f); LogFileWriter w = new LogFileWriter(f);

View File

@ -25,6 +25,8 @@ import com.google.flatbuffers.Table;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.impl.UIListener; import org.nd4j.autodiff.listeners.impl.UIListener;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -34,7 +36,7 @@ import org.nd4j.graph.UIEvent;
import org.nd4j.graph.UIGraphStructure; import org.nd4j.graph.UIGraphStructure;
import org.nd4j.graph.UIStaticInfoRecord; import org.nd4j.graph.UIStaticInfoRecord;
import org.nd4j.graph.ui.LogFileWriter; import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.IrisDataSetIterator; import org.nd4j.linalg.dataset.IrisDataSetIterator;
@ -51,11 +53,8 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class UIListenerTest extends BaseNd4jTest { public class UIListenerTest extends BaseNd4jTestWithBackends {
public UIListenerTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -65,7 +64,9 @@ public class UIListenerTest extends BaseNd4jTest {
@Test @Test
public void testUIListenerBasic(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -101,7 +102,9 @@ public class UIListenerTest extends BaseNd4jTest {
} }
@Test @Test
public void testUIListenerContinue(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
SameDiff sd1 = getSimpleNet(); SameDiff sd1 = getSimpleNet();
@ -192,7 +195,9 @@ public class UIListenerTest extends BaseNd4jTest {
} }
@Test @Test
public void testUIListenerBadContinue(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
SameDiff sd1 = getSimpleNet(); SameDiff sd1 = getSimpleNet();

View File

@ -23,18 +23,17 @@ package org.nd4j.evaluation;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.custom.CustomEvaluation; import org.nd4j.evaluation.custom.CustomEvaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
public class CustomEvaluationTest extends BaseNd4jTest { public class CustomEvaluationTest extends BaseNd4jTestWithBackends {
public CustomEvaluationTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -42,8 +41,10 @@ public class CustomEvaluationTest extends BaseNd4jTest {
} }
@Test @Test
public void customEvalTest(){ @ParameterizedTest
CustomEvaluation accuracyEval = new CustomEvaluation<Pair<Number, Long>>( @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void customEvalTest(Nd4jBackend backend){
CustomEvaluation accuracyEval = new CustomEvaluation<>(
(labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)), (labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)),
CustomEvaluation.mergeConcatenate()); CustomEvaluation.mergeConcatenate());

View File

@ -21,6 +21,8 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.evaluation.classification.EvaluationCalibration;
@ -29,25 +31,24 @@ import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
public class EmptyEvaluationTests extends BaseNd4jTest { public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
public EmptyEvaluationTests(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';
} }
@Test @Test
public void testEmptyEvaluation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyEvaluation (Nd4jBackend backend) {
Evaluation e = new Evaluation(); Evaluation e = new Evaluation();
System.out.println(e.stats()); System.out.println(e.stats());
@ -62,7 +63,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
} }
@Test @Test
public void testEmptyRegressionEvaluation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyRegressionEvaluation (Nd4jBackend backend) {
RegressionEvaluation re = new RegressionEvaluation(); RegressionEvaluation re = new RegressionEvaluation();
re.stats(); re.stats();
@ -76,7 +79,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
} }
@Test @Test
public void testEmptyEvaluationBinary() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyEvaluationBinary(Nd4jBackend backend) {
EvaluationBinary eb = new EvaluationBinary(); EvaluationBinary eb = new EvaluationBinary();
eb.stats(); eb.stats();
@ -91,7 +96,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
} }
@Test @Test
public void testEmptyROC() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyROC(Nd4jBackend backend) {
ROC roc = new ROC(); ROC roc = new ROC();
roc.stats(); roc.stats();
@ -106,7 +113,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
} }
@Test @Test
public void testEmptyROCBinary() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyROCBinary(Nd4jBackend backend) {
ROCBinary rb = new ROCBinary(); ROCBinary rb = new ROCBinary();
rb.stats(); rb.stats();
@ -121,7 +130,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
} }
@Test @Test
public void testEmptyROCMultiClass() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyROCMultiClass(Nd4jBackend backend) {
ROCMultiClass r = new ROCMultiClass(); ROCMultiClass r = new ROCMultiClass();
r.stats(); r.stats();
@ -136,7 +147,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
} }
@Test @Test
public void testEmptyEvaluationCalibration() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyEvaluationCalibration(Nd4jBackend backend) {
EvaluationCalibration ec = new EvaluationCalibration(); EvaluationCalibration ec = new EvaluationCalibration();
ec.stats(); ec.stats();

View File

@ -21,9 +21,11 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin; import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
@ -36,11 +38,8 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class EvalCustomThreshold extends BaseNd4jTest { public class EvalCustomThreshold extends BaseNd4jTestWithBackends {
public EvalCustomThreshold(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -48,7 +47,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationCustomBinaryThreshold() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCustomBinaryThreshold(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
//Sanity checks: 0.5 threshold for 1-output and 2-output binary cases //Sanity checks: 0.5 threshold for 1-output and 2-output binary cases
@ -114,7 +115,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationCostArray() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCostArray(Nd4jBackend backend) {
int nExamples = 20; int nExamples = 20;
@ -162,7 +165,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinaryCustomThreshold() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryCustomThreshold(Nd4jBackend backend) {
//Sanity check: same results for 0.5 threshold vs. default (no threshold) //Sanity check: same results for 0.5 threshold vs. default (no threshold)
int nExamples = 20; int nExamples = 20;

View File

@ -21,6 +21,8 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.evaluation.classification.EvaluationCalibration;
@ -31,7 +33,7 @@ import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve; import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -42,11 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class EvalJsonTest extends BaseNd4jTest { public class EvalJsonTest extends BaseNd4jTestWithBackends {
public EvalJsonTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -54,7 +53,9 @@ public class EvalJsonTest extends BaseNd4jTest {
} }
@Test @Test
public void testSerdeEmpty() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerdeEmpty(Nd4jBackend backend) {
boolean print = false; boolean print = false;
IEvaluation[] arr = new IEvaluation[] {new Evaluation(), new EvaluationBinary(), new ROCBinary(10), IEvaluation[] arr = new IEvaluation[] {new Evaluation(), new EvaluationBinary(), new ROCBinary(10),
@ -73,8 +74,10 @@ public class EvalJsonTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testSerde() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerde(Nd4jBackend backend) {
boolean print = false; boolean print = false;
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -121,8 +124,10 @@ public class EvalJsonTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testSerdeExactRoc() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerdeExactRoc(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
boolean print = false; boolean print = false;
@ -199,8 +204,10 @@ public class EvalJsonTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testJsonYamlCurves() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testJsonYamlCurves(Nd4jBackend backend) {
ROC roc = new ROC(0); ROC roc = new ROC(0);
INDArray evalLabel = INDArray evalLabel =
@ -251,8 +258,10 @@ public class EvalJsonTest extends BaseNd4jTest {
} }
@Test @Test
public void testJsonWithCustomThreshold() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testJsonWithCustomThreshold(Nd4jBackend backend) {
//Evaluation - binary threshold //Evaluation - binary threshold
Evaluation e = new Evaluation(0.25); Evaluation e = new Evaluation(0.25);

View File

@ -21,8 +21,10 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -39,11 +41,8 @@ import static org.junit.jupiter.api.Assertions.*;
import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
public class EvalTest extends BaseNd4jTest { public class EvalTest extends BaseNd4jTestWithBackends {
public EvalTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -52,7 +51,9 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
public void testEval() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEval(Nd4jBackend backend) {
int classNum = 5; int classNum = 5;
Evaluation eval = new Evaluation (classNum); Evaluation eval = new Evaluation (classNum);
@ -91,7 +92,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testEval2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEval2(Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
Evaluation first = null; Evaluation first = null;
@ -150,7 +153,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testStringListLabels() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStringListLabels(Nd4jBackend backend) {
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
@ -167,7 +172,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testStringHashLabels() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStringHashLabels(Nd4jBackend backend) {
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
@ -184,7 +191,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvalMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalMasking(Nd4jBackend backend) {
int miniBatch = 5; int miniBatch = 5;
int nOut = 3; int nOut = 3;
int tsLength = 6; int tsLength = 6;
@ -251,7 +260,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testFalsePerfectRecall() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFalsePerfectRecall(Nd4jBackend backend) {
int testSize = 100; int testSize = 100;
int numClasses = 5; int numClasses = 5;
int winner = 1; int winner = 1;
@ -284,7 +295,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationMerging(Nd4jBackend backend) {
int nRows = 20; int nRows = 20;
int nCols = 3; int nCols = 3;
@ -358,7 +371,9 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
public void testSingleClassBinaryClassification() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSingleClassBinaryClassification(Nd4jBackend backend) {
Evaluation eval = new Evaluation(1); Evaluation eval = new Evaluation(1);
@ -387,7 +402,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvalInvalid() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalInvalid(Nd4jBackend backend) {
Evaluation e = new Evaluation(5); Evaluation e = new Evaluation(5);
e.eval(0, 1); e.eval(0, 1);
e.eval(1, 0); e.eval(1, 0);
@ -400,7 +417,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvalMethods() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalMethods(Nd4jBackend backend) {
//Check eval(int,int) vs. eval(INDArray,INDArray) //Check eval(int,int) vs. eval(INDArray,INDArray)
Evaluation e1 = new Evaluation(4); Evaluation e1 = new Evaluation(4);
@ -443,7 +462,9 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
public void testTopNAccuracy() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopNAccuracy(Nd4jBackend backend) {
Evaluation e = new Evaluation(null, 3); Evaluation e = new Evaluation(null, 3);
@ -504,7 +525,9 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
public void testTopNAccuracyMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopNAccuracyMerging(Nd4jBackend backend) {
Evaluation e1 = new Evaluation(null, 3); Evaluation e1 = new Evaluation(null, 3);
Evaluation e2 = new Evaluation(null, 3); Evaluation e2 = new Evaluation(null, 3);
@ -552,7 +575,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testBinaryCase() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBinaryCase(Nd4jBackend backend) {
INDArray ones10 = Nd4j.ones(10, 1); INDArray ones10 = Nd4j.ones(10, 1);
INDArray ones4 = Nd4j.ones(4, 1); INDArray ones4 = Nd4j.ones(4, 1);
INDArray zeros4 = Nd4j.zeros(4, 1); INDArray zeros4 = Nd4j.zeros(4, 1);
@ -581,7 +606,9 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testF1FBeta_MicroMacroAveraging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testF1FBeta_MicroMacroAveraging(Nd4jBackend backend) {
//Confusion matrix: rows = actual, columns = predicted //Confusion matrix: rows = actual, columns = predicted
//[3, 1, 0] //[3, 1, 0]
//[2, 2, 1] //[2, 2, 1]
@ -722,7 +749,9 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
public void testConfusionMatrixStats() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConfusionMatrixStats(Nd4jBackend backend) {
Evaluation e = new Evaluation(); Evaluation e = new Evaluation();
@ -743,6 +772,8 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalBinaryMetrics(){ public void testEvalBinaryMetrics(){
Evaluation ePosClass1_nOut2 = new Evaluation(2, 1); Evaluation ePosClass1_nOut2 = new Evaluation(2, 1);
@ -864,6 +895,8 @@ public class EvalTest extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConfusionMatrixString(){ public void testConfusionMatrixString(){
Evaluation e = new Evaluation(Arrays.asList("a","b","c")); Evaluation e = new Evaluation(Arrays.asList("a","b","c"));
@ -914,6 +947,8 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationNaNs(){ public void testEvaluationNaNs(){
Evaluation e = new Evaluation(); Evaluation e = new Evaluation();
@ -929,6 +964,8 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentation(){ public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1023,6 +1060,8 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLabelReset(){ public void testLabelReset(){
Map<Integer,String> m = new HashMap<>(); Map<Integer,String> m = new HashMap<>();
@ -1056,6 +1095,8 @@ public class EvalTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalStatsBinaryCase(){ public void testEvalStatsBinaryCase(){
//Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case //Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case

View File

@ -21,9 +21,11 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -38,11 +40,8 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.nd4j.evaluation.classification.EvaluationBinary.Metric.*; import static org.nd4j.evaluation.classification.EvaluationBinary.Metric.*;
public class EvaluationBinaryTest extends BaseNd4jTest { public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
public EvaluationBinaryTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -50,7 +49,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinary() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary(Nd4jBackend backend) {
//Compare EvaluationBinary to Evaluation class //Compare EvaluationBinary to Evaluation class
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
EvaluationBinary first = null; EvaluationBinary first = null;
@ -136,7 +137,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinaryMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryMerging(Nd4jBackend backend) {
int nOut = 4; int nOut = 4;
int[] shape1 = {30, nOut}; int[] shape1 = {30, nOut};
int[] shape2 = {50, nOut}; int[] shape2 = {50, nOut};
@ -163,7 +166,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinaryPerOutputMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryPerOutputMasking(Nd4jBackend backend) {
//Provide a mask array: "ignore" the masked steps //Provide a mask array: "ignore" the masked steps
@ -172,7 +177,7 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
INDArray labels = Nd4j.create(new double[][] {{1, 1, 1}, {0, 0, 0}, {1, 1, 1}, {0, 1, 1}, {1, 0, 1}}); INDArray labels = Nd4j.create(new double[][] {{1, 1, 1}, {0, 0, 0}, {1, 1, 1}, {0, 1, 1}, {1, 0, 1}});
INDArray predicted = Nd4j.create(new double[][] {{0.9, 0.9, 0.9}, {0.7, 0.7, 0.7}, {0.6, 0.6, 0.6}, INDArray predicted = Nd4j.create(new double[][] {{0.9, 0.9, 0.9}, {0.7, 0.7, 0.7}, {0.6, 0.6, 0.6},
{0.4, 0.4, 0.4}, {0.1, 0.1, 0.1}}); {0.4, 0.4, 0.4}, {0.1, 0.1, 0.1}});
//Correct? //Correct?
// Y Y m // Y Y m
@ -206,7 +211,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testTimeSeriesEval() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTimeSeriesEval(Nd4jBackend backend) {
int[] shape = {2, 4, 3}; int[] shape = {2, 4, 3};
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -230,12 +237,14 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinaryWithROC() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryWithROC(Nd4jBackend backend) {
//Simple test for nested ROCBinary in EvaluationBinary //Simple test for nested ROCBinary in EvaluationBinary
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray l1 = Nd4j.getExecutioner() INDArray l1 = Nd4j.getExecutioner()
.exec(new BernoulliDistribution(Nd4j.createUninitialized(new int[] {50, 4}), 0.5)); .exec(new BernoulliDistribution(Nd4j.createUninitialized(new int[] {50, 4}), 0.5));
INDArray p1 = Nd4j.rand(50, 4); INDArray p1 = Nd4j.rand(50, 4);
EvaluationBinary eb = new EvaluationBinary(4, 30); EvaluationBinary eb = new EvaluationBinary(4, 30);
@ -247,7 +256,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
@Test @Test
public void testEvaluationBinary3d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -281,7 +292,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinary4d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -315,7 +328,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinary3dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -376,7 +391,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testEvaluationBinary4dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -21,8 +21,10 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -39,19 +41,18 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class EvaluationCalibrationTest extends BaseNd4jTest { public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
public EvaluationCalibrationTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering () {
return 'c'; return 'c';
} }
@Test @Test
public void testReliabilityDiagram() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReliabilityDiagram (Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
EvaluationCalibration first = null; EvaluationCalibration first = null;
@ -142,8 +143,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testLabelAndPredictionCounts() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLabelAndPredictionCounts (Nd4jBackend backend) {
int minibatch = 50; int minibatch = 50;
int nClasses = 3; int nClasses = 3;
@ -170,8 +173,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
assertArrayEquals(expPredictionCount, ec.getPredictionCountsEachClass()); assertArrayEquals(expPredictionCount, ec.getPredictionCountsEachClass());
} }
@Test @Test
public void testResidualPlots() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testResidualPlots (Nd4jBackend backend) {
int minibatch = 50; int minibatch = 50;
int nClasses = 3; int nClasses = 3;
@ -271,7 +276,9 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
} }
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentation(){ public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -365,8 +372,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testEvaluationCalibration3d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCalibration3d (Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -397,8 +406,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
assertEquals(e2d.stats(), e3d.stats()); assertEquals(e2d.stats(), e3d.stats());
} }
@Test @Test
public void testEvaluationCalibration3dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCalibration3dMasking (Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);

View File

@ -23,6 +23,8 @@ package org.nd4j.evaluation;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.evaluation.classification.EvaluationCalibration;
@ -30,17 +32,14 @@ import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary; import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
public class NewInstanceTest extends BaseNd4jTest { public class NewInstanceTest extends BaseNd4jTestWithBackends {
public NewInstanceTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -48,7 +47,9 @@ public class NewInstanceTest extends BaseNd4jTest {
} }
@Test @Test
public void testNewInstances() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewInstances(Nd4jBackend backend) {
boolean print = true; boolean print = true;
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -21,10 +21,12 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary; import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -39,19 +41,17 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class ROCBinaryTest extends BaseNd4jTest { public class ROCBinaryTest extends BaseNd4jTestWithBackends {
public ROCBinaryTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';
} }
@Test @Test
public void testROCBinary() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary(Nd4jBackend backend) {
//Compare ROCBinary to ROC class //Compare ROCBinary to ROC class
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
@ -145,8 +145,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testRocBinaryMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBinaryMerging(Nd4jBackend backend) {
for (int nSteps : new int[]{30, 0}) { //0 == exact for (int nSteps : new int[]{30, 0}) { //0 == exact
int nOut = 4; int nOut = 4;
int[] shape1 = {30, nOut}; int[] shape1 = {30, nOut};
@ -175,8 +177,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
} }
@Test @Test
public void testROCBinaryPerOutputMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinaryPerOutputMasking(Nd4jBackend backend) {
for (int nSteps : new int[]{30, 0}) { //0 == exact for (int nSteps : new int[]{30, 0}) { //0 == exact
@ -215,8 +219,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
@Test @Test
public void testROCBinary3d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -249,8 +255,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testROCBinary4d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -283,8 +291,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testROCBinary3dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -344,8 +354,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testROCBinary4dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -21,12 +21,14 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary; import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve; import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -39,11 +41,8 @@ import java.util.*;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class ROCTest extends BaseNd4jTest { public class ROCTest extends BaseNd4jTestWithBackends {
public ROCTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -83,8 +82,10 @@ public class ROCTest extends BaseNd4jTest {
expFPR.put(10 / 10.0, 0.0 / totalNegatives); expFPR.put(10 / 10.0, 0.0 / totalNegatives);
} }
@Test @Test
public void testRocBasic() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBasic(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax) //2 outputs here - probability distribution over classes (softmax)
INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601}, {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
@ -126,8 +127,10 @@ public class ROCTest extends BaseNd4jTest {
assertEquals(1.0, auc, 1e-6); assertEquals(1.0, auc, 1e-6);
} }
@Test @Test
public void testRocBasicSingleClass() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBasicSingleClass(Nd4jBackend backend) {
//1 output here - single probability value (sigmoid) //1 output here - single probability value (sigmoid)
//add 0.001 to avoid numerical/rounding issues (float vs. double, etc) //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
@ -164,8 +167,10 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void testRoc() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRoc(Nd4jBackend backend) {
//Previous tests allowed for a perfect classifier with right threshold... //Previous tests allowed for a perfect classifier with right threshold...
INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}}); INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}});
@ -249,8 +254,10 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void testRocTimeSeriesNoMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
//Same as first test... //Same as first test...
//2 outputs here - probability distribution over classes (softmax) //2 outputs here - probability distribution over classes (softmax)
@ -296,8 +303,10 @@ public class ROCTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testRocTimeSeriesMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocTimeSeriesMasking(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax) //2 outputs here - probability distribution over classes (softmax)
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601}, {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
@ -346,8 +355,10 @@ public class ROCTest extends BaseNd4jTest {
@Test @Test
public void testCompareRocAndRocMultiClass() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
//For 2 class case: ROC and Multi-class ROC should be the same... //For 2 class case: ROC and Multi-class ROC should be the same...
@ -376,8 +387,10 @@ public class ROCTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testCompare2Vs3Classes() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCompare2Vs3Classes(Nd4jBackend backend) {
//ROC multi-class: 2 vs. 3 classes should be the same, if we add two of the classes together... //ROC multi-class: 2 vs. 3 classes should be the same, if we add two of the classes together...
//Both methods implement one vs. all ROC/AUC in different ways //Both methods implement one vs. all ROC/AUC in different ways
@ -425,8 +438,10 @@ public class ROCTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testROCMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCMerging(Nd4jBackend backend) {
int nArrays = 10; int nArrays = 10;
int minibatch = 64; int minibatch = 64;
int nROCs = 3; int nROCs = 3;
@ -470,8 +485,10 @@ public class ROCTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testROCMerging2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCMerging2(Nd4jBackend backend) {
int nArrays = 10; int nArrays = 10;
int minibatch = 64; int minibatch = 64;
int exactAllocBlockSize = 10; int exactAllocBlockSize = 10;
@ -515,8 +532,10 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void testROCMultiMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCMultiMerging(Nd4jBackend backend) {
int nArrays = 10; int nArrays = 10;
int minibatch = 64; int minibatch = 64;
@ -563,8 +582,10 @@ public class ROCTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testAUCPrecisionRecall() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAUCPrecisionRecall(Nd4jBackend backend) {
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob //Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
//at threshold 0 to 0.24999: tp=2, fp=1, fn=0, tn=0 prec=2/(2+1)=0.666, recall=2/2=1.0 //at threshold 0 to 0.24999: tp=2, fp=1, fn=0, tn=0 prec=2/(2+1)=0.666, recall=2/2=1.0
//at threshold 0.25 to 0.33: tp=2, fp=0, fn=0, tn=1 prec=2/2=1, recall=2/2=1 //at threshold 0.25 to 0.33: tp=2, fp=0, fn=0, tn=1 prec=2/2=1, recall=2/2=1
@ -610,8 +631,10 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void testRocAucExact() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocAucExact(Nd4jBackend backend) {
//Check the implementation vs. Scikitlearn //Check the implementation vs. Scikitlearn
/* /*
@ -773,8 +796,10 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void rocExactEdgeCaseReallocation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
//Set reallocation block size to say 20, but then evaluate a 100-length array //Set reallocation block size to say 20, but then evaluate a 100-length array
@ -785,8 +810,10 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
public void testPrecisionRecallCurveGetPointMethods() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
double[] threshold = new double[101]; double[] threshold = new double[101];
double[] precision = threshold; double[] precision = threshold;
double[] recall = new double[101]; double[] recall = new double[101];
@ -821,8 +848,10 @@ public class ROCTest extends BaseNd4jTest {
} }
} }
@Test @Test
public void testPrecisionRecallCurveConfusion() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
//Sanity check: values calculated from the confusion matrix should match the PR curve values //Sanity check: values calculated from the confusion matrix should match the PR curve values
for (boolean removeRedundantPts : new boolean[] {true, false}) { for (boolean removeRedundantPts : new boolean[] {true, false}) {
@ -860,7 +889,9 @@ public class ROCTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocMerge(){ public void testRocMerge(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -904,7 +935,9 @@ public class ROCTest extends BaseNd4jTest {
assertEquals(auprc, auprcAct, 1e-6); assertEquals(auprc, auprcAct, 1e-6);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocMultiMerge(){ public void testRocMultiMerge(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -953,7 +986,9 @@ public class ROCTest extends BaseNd4jTest {
} }
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBinaryMerge(){ public void testRocBinaryMerge(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -998,7 +1033,9 @@ public class ROCTest extends BaseNd4jTest {
} }
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentationBinary(){ public void testSegmentationBinary(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1088,7 +1125,9 @@ public class ROCTest extends BaseNd4jTest {
} }
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentation(){ public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -21,9 +21,11 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -40,11 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
public class RegressionEvalTest extends BaseNd4jTest { public class RegressionEvalTest extends BaseNd4jTestWithBackends {
public RegressionEvalTest(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -52,7 +51,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test() @Test()
public void testEvalParameters() { public void testEvalParameters(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
int specCols = 5; int specCols = 5;
INDArray labels = Nd4j.ones(3); INDArray labels = Nd4j.ones(3);
@ -65,7 +64,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testPerfectPredictions() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPerfectPredictions(Nd4jBackend backend) {
int nCols = 5; int nCols = 5;
int nTestArrays = 100; int nTestArrays = 100;
@ -92,7 +93,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testKnownValues() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testKnownValues(Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType(); DataType dtypeBefore = Nd4j.defaultFloatingPointType();
RegressionEvaluation first = null; RegressionEvaluation first = null;
@ -148,7 +151,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
@Test @Test
public void testRegressionEvaluationMerging() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEvaluationMerging(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
int nRows = 20; int nRows = 20;
@ -189,7 +194,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testRegressionEvalPerOutputMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEvalPerOutputMasking(Nd4jBackend backend) {
INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}}); INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}});
@ -216,6 +223,8 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEvalTimeSeriesSplit(){ public void testRegressionEvalTimeSeriesSplit(){
INDArray out1 = Nd4j.rand(new int[]{3, 5, 20}); INDArray out1 = Nd4j.rand(new int[]{3, 5, 20});
@ -238,7 +247,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testRegressionEval3d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -270,7 +281,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testRegressionEval4d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -302,7 +315,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testRegressionEval3dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -361,7 +376,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
} }
@Test @Test
public void testRegressionEval4dMasking() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -22,10 +22,12 @@ package org.nd4j.evaluation;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -34,11 +36,8 @@ import java.nio.charset.StandardCharsets;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestLegacyJsonLoading extends BaseNd4jTest { public class TestLegacyJsonLoading extends BaseNd4jTestWithBackends {
public TestLegacyJsonLoading(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -46,7 +45,9 @@ public class TestLegacyJsonLoading extends BaseNd4jTest {
} }
@Test @Test
public void testEvalLegacyFormat() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalLegacyFormat(Nd4jBackend backend) throws Exception {
File f = new ClassPathResource("regression_testing/eval_100b/evaluation.json").getFile(); File f = new ClassPathResource("regression_testing/eval_100b/evaluation.json").getFile();
String s = FileUtils.readFileToString(f, StandardCharsets.UTF_8); String s = FileUtils.readFileToString(f, StandardCharsets.UTF_8);

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -38,17 +39,14 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class AveragingTests extends BaseNd4jTest { public class AveragingTests extends BaseNd4jTestWithBackends {
private final int THREADS = 16; private final int THREADS = 16;
private final int LENGTH = 51200 * 4; private final int LENGTH = 51200 * 4;
DataType initialType; DataType initialType = Nd4j.dataType();
public AveragingTests(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
public void setUp() { public void setUp() {
@ -63,7 +61,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test @Test
public void testSingleDeviceAveraging1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSingleDeviceAveraging1(Nd4jBackend backend) {
INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0); INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0);
INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0); INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0);
INDArray array3 = Nd4j.valueArrayOf(LENGTH, 3.0); INDArray array3 = Nd4j.valueArrayOf(LENGTH, 3.0);
@ -110,7 +110,9 @@ public class AveragingTests extends BaseNd4jTest {
} }
@Test @Test
public void testSingleDeviceAveraging2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSingleDeviceAveraging2(Nd4jBackend backend) {
INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH); INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH);
List<INDArray> arrays = new ArrayList<>(); List<INDArray> arrays = new ArrayList<>();
for (int i = 0; i < THREADS; i++) for (int i = 0; i < THREADS; i++)
@ -127,7 +129,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test @Test
public void testAccumulation1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAccumulation1(Nd4jBackend backend) {
INDArray array1 = Nd4j.create(100).assign(1.0); INDArray array1 = Nd4j.create(100).assign(1.0);
INDArray array2 = Nd4j.create(100).assign(2.0); INDArray array2 = Nd4j.create(100).assign(2.0);
INDArray array3 = Nd4j.create(100).assign(3.0); INDArray array3 = Nd4j.create(100).assign(3.0);
@ -140,7 +144,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test @Test
public void testAccumulation2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAccumulation2(Nd4jBackend backend) {
INDArray array1 = Nd4j.create(100).assign(1.0); INDArray array1 = Nd4j.create(100).assign(1.0);
INDArray array2 = Nd4j.create(100).assign(2.0); INDArray array2 = Nd4j.create(100).assign(2.0);
INDArray array3 = Nd4j.create(100).assign(3.0); INDArray array3 = Nd4j.create(100).assign(3.0);
@ -155,7 +161,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test @Test
public void testAccumulation3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAccumulation3(Nd4jBackend backend) {
// we want to ensure that cuda backend is able to launch this op on cpu // we want to ensure that cuda backend is able to launch this op on cpu
Nd4j.getAffinityManager().allowCrossDeviceAccess(false); Nd4j.getAffinityManager().allowCrossDeviceAccess(false);

View File

@ -23,8 +23,9 @@ package org.nd4j.linalg;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -34,15 +35,14 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
@Slf4j @Slf4j
public class DataTypeTest extends BaseNd4jTest { public class DataTypeTest extends BaseNd4jTestWithBackends {
public DataTypeTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testDataTypes() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDataTypes(Nd4jBackend backend) throws Exception {
for (val type : DataType.values()) { for (val type : DataType.values()) {
if (DataType.UTF8.equals(type) || DataType.UNKNOWN.equals(type) || DataType.COMPRESSED.equals(type)) if (DataType.UTF8.equals(type) || DataType.UNKNOWN.equals(type) || DataType.COMPRESSED.equals(type))
continue; continue;

View File

@ -21,20 +21,17 @@
package org.nd4j.linalg; package org.nd4j.linalg;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
@RunWith(Parameterized.class)
public class InputValidationTests extends BaseNd4jTest {
public InputValidationTests(Nd4jBackend backend) { public class InputValidationTests extends BaseNd4jTestWithBackends {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -45,7 +42,9 @@ public class InputValidationTests extends BaseNd4jTest {
///////////////////// Broadcast Tests /////////////////////// ///////////////////// Broadcast Tests ///////////////////////
@Test @Test
public void testInvalidColVectorOp1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidColVectorOp1(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10); INDArray first = Nd4j.create(10, 10);
INDArray col = Nd4j.create(5, 1); INDArray col = Nd4j.create(5, 1);
try { try {
@ -57,7 +56,9 @@ public class InputValidationTests extends BaseNd4jTest {
} }
@Test @Test
public void testInvalidColVectorOp2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidColVectorOp2(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10); INDArray first = Nd4j.create(10, 10);
INDArray col = Nd4j.create(5, 1); INDArray col = Nd4j.create(5, 1);
try { try {
@ -69,7 +70,9 @@ public class InputValidationTests extends BaseNd4jTest {
} }
@Test @Test
public void testInvalidRowVectorOp1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidRowVectorOp1(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10); INDArray first = Nd4j.create(10, 10);
INDArray row = Nd4j.create(1, 5); INDArray row = Nd4j.create(1, 5);
try { try {
@ -81,7 +84,9 @@ public class InputValidationTests extends BaseNd4jTest {
} }
@Test @Test
public void testInvalidRowVectorOp2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidRowVectorOp2(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10); INDArray first = Nd4j.create(10, 10);
INDArray row = Nd4j.create(1, 5); INDArray row = Nd4j.create(1, 5);
try { try {

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.commons.lang3.RandomUtils; import org.apache.commons.lang3.RandomUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
@ -47,14 +48,13 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class LoneTest extends BaseNd4jTest { public class LoneTest extends BaseNd4jTestWithBackends {
public LoneTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testSoftmaxStability() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSoftmaxStability(Nd4jBackend backend) {
INDArray input = Nd4j.create(new double[]{-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose(); INDArray input = Nd4j.create(new double[]{-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose();
// System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo())); // System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo()));
INDArray output = Nd4j.create(DataType.DOUBLE, 10, 1); INDArray output = Nd4j.create(DataType.DOUBLE, 10, 1);
@ -68,7 +68,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void testFlattenedView() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFlattenedView(Nd4jBackend backend) {
int rows = 8; int rows = 8;
int cols = 8; int cols = 8;
int dim2 = 4; int dim2 = 4;
@ -104,7 +106,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void testIndexingColVec() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexingColVec(Nd4jBackend backend) {
int elements = 5; int elements = 5;
INDArray rowVector = Nd4j.linspace(1, elements, elements).reshape(1, elements); INDArray rowVector = Nd4j.linspace(1, elements, elements).reshape(1, elements);
INDArray colVector = rowVector.transpose(); INDArray colVector = rowVector.transpose();
@ -123,7 +127,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void concatScalarVectorIssue() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void concatScalarVectorIssue(Nd4jBackend backend) {
//A bug was found when the first array that concat sees is a scalar and the rest vectors + scalars //A bug was found when the first array that concat sees is a scalar and the rest vectors + scalars
INDArray arr1 = Nd4j.create(1, 1); INDArray arr1 = Nd4j.create(1, 1);
INDArray arr2 = Nd4j.create(1, 8); INDArray arr2 = Nd4j.create(1, 8);
@ -133,7 +139,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void reshapeTensorMmul() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void reshapeTensorMmul(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 2, 12).reshape(2, 3, 2); INDArray a = Nd4j.linspace(1, 2, 12).reshape(2, 3, 2);
INDArray b = Nd4j.linspace(3, 4, 4).reshape(2, 2); INDArray b = Nd4j.linspace(3, 4, 4).reshape(2, 2);
int[][] axes = new int[2][]; int[][] axes = new int[2][];
@ -145,7 +153,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void maskWhenMerge() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void maskWhenMerge(Nd4jBackend backend) {
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5)); DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3)); DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
List<DataSet> dataSetList = new ArrayList<DataSet>(); List<DataSet> dataSetList = new ArrayList<DataSet>();
@ -160,7 +170,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void testRelu() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRelu(Nd4jBackend backend) {
INDArray aA = Nd4j.linspace(-3, 4, 8).reshape(2, 4); INDArray aA = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4); INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
INDArray b = Nd4j.getExecutioner().exec(new Tanh(aA)); INDArray b = Nd4j.getExecutioner().exec(new Tanh(aA));
@ -172,7 +184,7 @@ public class LoneTest extends BaseNd4jTest {
@Test @Test
//broken at a threshold //broken at a threshold
public void testArgMax() { public void testArgMax(Nd4jBackend backend) {
int max = 63; int max = 63;
INDArray A = Nd4j.linspace(1, max, max).reshape(1, max); INDArray A = Nd4j.linspace(1, max, max).reshape(1, max);
int currentArgMax = Nd4j.argMax(A).getInt(0); int currentArgMax = Nd4j.argMax(A).getInt(0);
@ -186,7 +198,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void testRPF() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRPF(Nd4jBackend backend) {
val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3); val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3);
log.info("--------"); log.info("--------");
@ -199,7 +213,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void testConcat3D_Vstack_C() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat3D_Vstack_C(Nd4jBackend backend) {
val shape = new long[]{1, 1000, 20}; val shape = new long[]{1, 1000, 20};
List<INDArray> cArrays = new ArrayList<>(); List<INDArray> cArrays = new ArrayList<>();
@ -229,7 +245,9 @@ public class LoneTest extends BaseNd4jTest {
@Test @Test
public void testGetRow1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRow1(Nd4jBackend backend) {
INDArray array = Nd4j.create(10000, 10000); INDArray array = Nd4j.create(10000, 10000);
//Thread.sleep(10000); //Thread.sleep(10000);
@ -256,7 +274,7 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test() @Test()
public void checkIllegalElementOps() { public void checkIllegalElementOps(Nd4jBackend backend) {
assertThrows(Exception.class,() -> { assertThrows(Exception.class,() -> {
INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5); INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5);
INDArray B = A.dup().reshape(2, 2, 5); INDArray B = A.dup().reshape(2, 2, 5);
@ -268,7 +286,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void checkSliceofSlice() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void checkSliceofSlice(Nd4jBackend backend) {
/* /*
Issue 1: Slice of slice with c order and f order views are not equal Issue 1: Slice of slice with c order and f order views are not equal
@ -308,7 +328,9 @@ public class LoneTest extends BaseNd4jTest {
} }
@Test @Test
public void checkWithReshape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void checkWithReshape(Nd4jBackend backend) {
INDArray arr = Nd4j.create(1, 3); INDArray arr = Nd4j.create(1, 3);
INDArray reshaped = arr.reshape('f', 3, 1); INDArray reshaped = arr.reshape('f', 3, 1);
for (int i=0;i<reshaped.length();i++) { for (int i=0;i<reshaped.length();i++) {

View File

@ -21,6 +21,8 @@
package org.nd4j.linalg; package org.nd4j.linalg;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -28,11 +30,8 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class MmulBug extends BaseNd4jTest { public class MmulBug extends BaseNd4jTestWithBackends {
public MmulBug(Nd4jBackend b){
super(b);
}
@Override @Override
public char ordering(){ public char ordering(){
@ -40,7 +39,9 @@ public class MmulBug extends BaseNd4jTest {
} }
@Test @Test
public void simpleTest() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void simpleTest(Nd4jBackend backend) {
INDArray m1 = Nd4j.create(new double[][] {{1.0}, {2.0}, {3.0}, {4.0}}); INDArray m1 = Nd4j.create(new double[][] {{1.0}, {2.0}, {3.0}, {4.0}});
m1 = m1.reshape(2, 2); m1 = m1.reshape(2, 2);

View File

@ -25,8 +25,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -58,19 +59,14 @@ import static org.junit.jupiter.api.Assertions.*;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
@Slf4j @Slf4j
public class NDArrayTestsFortran extends BaseNd4jTest { public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
public NDArrayTestsFortran(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testScalarOps() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarOps(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3}); INDArray n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
assertEquals(27d, n.length(), 1e-1); assertEquals(27d, n.length(), 1e-1);
n.addi(Nd4j.scalar(1d)); n.addi(Nd4j.scalar(1d));
@ -88,7 +84,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testColumnMmul() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testColumnMmul(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 10, 18, DataType.FLOAT).data(); DataBuffer data = Nd4j.linspace(1, 10, 18, DataType.FLOAT).data();
INDArray x2 = Nd4j.create(data, new long[] {2, 3, 3}); INDArray x2 = Nd4j.create(data, new long[] {2, 3, 3});
data = Nd4j.linspace(1, 12, 9, DataType.FLOAT).data(); data = Nd4j.linspace(1, 12, 9, DataType.FLOAT).data();
@ -119,7 +117,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testRowVectorGemm() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRowVectorGemm(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1).castTo(DataType.DOUBLE); INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1).castTo(DataType.DOUBLE);
INDArray other = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4).castTo(DataType.DOUBLE); INDArray other = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4).castTo(DataType.DOUBLE);
INDArray result = linspace.mmul(other); INDArray result = linspace.mmul(other);
@ -130,13 +130,17 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testRepmat() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRepmat(Nd4jBackend backend) {
INDArray rowVector = Nd4j.create(1, 4); INDArray rowVector = Nd4j.create(1, 4);
INDArray repmat = rowVector.repmat(4, 4); INDArray repmat = rowVector.repmat(4, 4);
assertTrue(Arrays.equals(new long[] {4, 16}, repmat.shape())); assertTrue(Arrays.equals(new long[] {4, 16}, repmat.shape()));
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReadWrite() throws Exception { public void testReadWrite() throws Exception {
INDArray write = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray write = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
ByteArrayOutputStream bos = new ByteArrayOutputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream();
@ -152,6 +156,8 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReadWriteDouble() throws Exception { public void testReadWriteDouble() throws Exception {
INDArray write = Nd4j.linspace(1, 4, 4, DataType.FLOAT); INDArray write = Nd4j.linspace(1, 4, 4, DataType.FLOAT);
ByteArrayOutputStream bos = new ByteArrayOutputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream();
@ -168,18 +174,17 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMultiThreading() throws Exception { public void testMultiThreading() throws Exception {
ExecutorService ex = ExecutorServiceProvider.getExecutorService(); ExecutorService ex = ExecutorServiceProvider.getExecutorService();
List<Future<?>> list = new ArrayList<>(100); List<Future<?>> list = new ArrayList<>(100);
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
Future<?> future = ex.submit(new Runnable() { Future<?> future = ex.submit(() -> {
@Override INDArray dot = Nd4j.linspace(1, 8, 8, DataType.DOUBLE);
public void run() {
INDArray dot = Nd4j.linspace(1, 8, 8, DataType.DOUBLE);
// System.out.println(Transforms.sigmoid(dot)); // System.out.println(Transforms.sigmoid(dot));
Transforms.sigmoid(dot); Transforms.sigmoid(dot);
}
}); });
list.add(future); list.add(future);
} }
@ -191,7 +196,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testBroadcastingGenerated() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastingGenerated(Nd4jBackend backend) {
int[][] broadcastShape = NDArrayCreationUtil.getRandomBroadCastShape(7, 6, 10); int[][] broadcastShape = NDArrayCreationUtil.getRandomBroadCastShape(7, 6, 10);
List<List<Pair<INDArray, String>>> broadCastList = new ArrayList<>(broadcastShape.length); List<List<Pair<INDArray, String>>> broadCastList = new ArrayList<>(broadcastShape.length);
for (int[] shape : broadcastShape) { for (int[] shape : broadcastShape) {
@ -206,7 +213,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
INDArray inputArrBroadcast = val.getFirst(); INDArray inputArrBroadcast = val.getFirst();
val destShape = NDArrayCreationUtil.broadcastToShape(inputArrBroadcast.shape(), 7); val destShape = NDArrayCreationUtil.broadcastToShape(inputArrBroadcast.shape(), 7);
INDArray output = inputArrBroadcast INDArray output = inputArrBroadcast
.broadcast(NDArrayCreationUtil.broadcastToShape(inputArrBroadcast.shape(), 7)); .broadcast(NDArrayCreationUtil.broadcastToShape(inputArrBroadcast.shape(), 7));
assertArrayEquals(destShape, output.shape()); assertArrayEquals(destShape, output.shape());
} }
} }
@ -216,7 +223,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testBroadCasting() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadCasting(Nd4jBackend backend) {
INDArray first = Nd4j.arange(0, 3).reshape(3, 1).castTo(DataType.DOUBLE); INDArray first = Nd4j.arange(0, 3).reshape(3, 1).castTo(DataType.DOUBLE);
INDArray ret = first.broadcast(3, 4); INDArray ret = first.broadcast(3, 4);
INDArray testRet = Nd4j.create(new double[][] {{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}}); INDArray testRet = Nd4j.create(new double[][] {{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}});
@ -229,14 +238,18 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testOneTensor() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneTensor(Nd4jBackend backend) {
INDArray arr = Nd4j.ones(1, 1, 1, 1, 1, 1, 1); INDArray arr = Nd4j.ones(1, 1, 1, 1, 1, 1, 1);
INDArray matrixToBroadcast = Nd4j.ones(1, 1); INDArray matrixToBroadcast = Nd4j.ones(1, 1);
assertEquals(matrixToBroadcast.broadcast(arr.shape()), arr); assertEquals(matrixToBroadcast.broadcast(arr.shape()), arr);
} }
@Test @Test
public void testSortWithIndicesDescending() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSortWithIndicesDescending(Nd4jBackend backend) {
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
//indices,data //indices,data
INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, false); INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, false);
@ -247,7 +260,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testSortDeadlock() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSortDeadlock(Nd4jBackend backend) {
val toSort = Nd4j.linspace(DataType.DOUBLE, 1, 32*768, 1).reshape(32, 768); val toSort = Nd4j.linspace(DataType.DOUBLE, 1, 32*768, 1).reshape(32, 768);
val sorted = Nd4j.sort(toSort.dup(), 1, false); val sorted = Nd4j.sort(toSort.dup(), 1, false);
@ -255,7 +270,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testSortWithIndices() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSortWithIndices(Nd4jBackend backend) {
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
//indices,data //indices,data
INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, true); INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, true);
@ -266,14 +283,18 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testNd4jSortScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNd4jSortScalar(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(1, -1); INDArray linspace = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(1, -1);
INDArray sorted = Nd4j.sort(linspace, 1, false); INDArray sorted = Nd4j.sort(linspace, 1, false);
// System.out.println(sorted); // System.out.println(sorted);
} }
@Test @Test
public void testSwapAxesFortranOrder() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSwapAxesFortranOrder(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}).castTo(DataType.DOUBLE); INDArray n = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}).castTo(DataType.DOUBLE);
for (int i = 0; i < n.slices(); i++) { for (int i = 0; i < n.slices(); i++) {
INDArray nSlice = n.slice(i); INDArray nSlice = n.slice(i);
@ -292,7 +313,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testDimShuffle() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDimShuffle(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray twoOneTwo = n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false}); INDArray twoOneTwo = n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false});
assertTrue(Arrays.equals(new long[] {2, 1, 2}, twoOneTwo.shape())); assertTrue(Arrays.equals(new long[] {2, 1, 2}, twoOneTwo.shape()));
@ -303,7 +326,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testGetVsGetScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetVsGetScalar(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
float element = a.getFloat(0, 1); float element = a.getFloat(0, 1);
double element2 = a.getDouble(0, 1); double element2 = a.getDouble(0, 1);
@ -316,7 +341,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testDivide() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDivide(Nd4jBackend backend) {
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2}); INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
INDArray div = two.div(two); INDArray div = two.div(two);
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage()); assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage());
@ -330,7 +357,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testSigmoid() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSigmoid(Nd4jBackend backend) {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f}); INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
INDArray sigmoid = Transforms.sigmoid(n, false); INDArray sigmoid = Transforms.sigmoid(n, false);
@ -339,7 +368,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testNeg() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNeg(Nd4jBackend backend) {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4}); INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
INDArray neg = Transforms.neg(n); INDArray neg = Transforms.neg(n);
@ -349,7 +380,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testCosineSim() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCosineSim(Nd4jBackend backend) {
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
double sim = Transforms.cosineSim(vec1, vec2); double sim = Transforms.cosineSim(vec1, vec2);
@ -364,7 +397,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testExp() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExp(Nd4jBackend backend) {
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f}); INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f});
INDArray exped = Transforms.exp(n); INDArray exped = Transforms.exp(n);
@ -374,7 +409,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalar(Nd4jBackend backend) {
INDArray a = Nd4j.scalar(1.0f); INDArray a = Nd4j.scalar(1.0f);
assertEquals(true, a.isScalar()); assertEquals(true, a.isScalar());
@ -386,7 +423,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testWrap() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testWrap(Nd4jBackend backend) {
int[] shape = {2, 4}; int[] shape = {2, 4};
INDArray d = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(shape[0], shape[1]); INDArray d = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(shape[0], shape[1]);
INDArray n = d; INDArray n = d;
@ -411,7 +450,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testGetRowFortran() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRowFortran(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.FLOAT).data(), new long[] {2, 2}); INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.FLOAT).data(), new long[] {2, 2});
INDArray column = Nd4j.create(new float[] {1, 3}); INDArray column = Nd4j.create(new float[] {1, 3});
INDArray column2 = Nd4j.create(new float[] {2, 4}); INDArray column2 = Nd4j.create(new float[] {2, 4});
@ -424,7 +465,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testGetColumnFortran() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumnFortran(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2}); INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2});
INDArray column = Nd4j.create(new double[] {1, 2}); INDArray column = Nd4j.create(new double[] {1, 2});
INDArray column2 = Nd4j.create(new double[] {3, 4}); INDArray column2 = Nd4j.create(new double[] {3, 4});
@ -438,7 +481,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testGetColumns() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumns(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3).castTo(DataType.DOUBLE); INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3).castTo(DataType.DOUBLE);
// log.info("Original: {}", matrix); // log.info("Original: {}", matrix);
INDArray matrixGet = matrix.getColumns(1, 2); INDArray matrixGet = matrix.getColumns(1, 2);
@ -452,7 +497,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testVectorInit() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorInit(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(); DataBuffer data = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data();
INDArray arr = Nd4j.create(data, new long[] {1, 4}); INDArray arr = Nd4j.create(data, new long[] {1, 4});
assertEquals(true, arr.isRowVector()); assertEquals(true, arr.isRowVector());
@ -465,7 +512,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testAssignOffset() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssignOffset(Nd4jBackend backend) {
INDArray arr = Nd4j.ones(5, 5); INDArray arr = Nd4j.ones(5, 5);
INDArray row = arr.slice(1); INDArray row = arr.slice(1);
row.assign(1); row.assign(1);
@ -473,7 +522,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testColumns() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testColumns(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE); INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE);
INDArray column = Nd4j.create(new double[] {1, 2, 3}); INDArray column = Nd4j.create(new double[] {1, 2, 3});
arr.putColumn(0, column); arr.putColumn(0, column);
@ -511,7 +562,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testPutRow() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRow(Nd4jBackend backend) {
INDArray d = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray d = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray n = d.dup(); INDArray n = d.dup();
@ -570,7 +623,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testInplaceTranspose() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInplaceTranspose(Nd4jBackend backend) {
INDArray test = Nd4j.rand(3, 4); INDArray test = Nd4j.rand(3, 4);
INDArray orig = test.dup(); INDArray orig = test.dup();
INDArray transposei = test.transposei(); INDArray transposei = test.transposei();
@ -585,7 +640,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testMmulF() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulF(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data(); DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data();
INDArray n = Nd4j.create(data, new long[] {1, 10}); INDArray n = Nd4j.create(data, new long[] {1, 10});
@ -603,7 +660,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testRowsColumns() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRowsColumns(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).data(); DataBuffer data = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).data();
INDArray rows = Nd4j.create(data, new long[] {2, 3}); INDArray rows = Nd4j.create(data, new long[] {2, 3});
assertEquals(2, rows.rows()); assertEquals(2, rows.rows());
@ -619,7 +678,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testTranspose() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTranspose(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.ones(100).castTo(DataType.DOUBLE).data(), new long[] {5, 5, 4}); INDArray n = Nd4j.create(Nd4j.ones(100).castTo(DataType.DOUBLE).data(), new long[] {5, 5, 4});
INDArray transpose = n.transpose(); INDArray transpose = n.transpose();
assertEquals(n.length(), transpose.length()); assertEquals(n.length(), transpose.length());
@ -647,7 +708,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testAddMatrix() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddMatrix(Nd4jBackend backend) {
INDArray five = Nd4j.ones(5); INDArray five = Nd4j.ones(5);
five.addi(five.dup()); five.addi(five.dup());
INDArray twos = Nd4j.valueArrayOf(5, 2); INDArray twos = Nd4j.valueArrayOf(5, 2);
@ -658,7 +721,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testMMul() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMMul(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}}); INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}}); INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
@ -669,7 +734,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testPutSlice() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutSlice(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3); INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3);
INDArray newSlice = Nd4j.create(DataType.DOUBLE, 3, 3); INDArray newSlice = Nd4j.create(DataType.DOUBLE, 3, 3);
Nd4j.exec(new PrintVariable(newSlice)); Nd4j.exec(new PrintVariable(newSlice));
@ -680,7 +747,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testRowVectorMultipleIndices() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4); INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
linear.putScalar(new long[] {0, 1}, 1); linear.putScalar(new long[] {0, 1}, 1);
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage()); assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage());
@ -689,7 +758,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testDim1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDim1(Nd4jBackend backend) {
INDArray sum = Nd4j.linspace(1, 2, 2, DataType.DOUBLE).reshape(2, 1); INDArray sum = Nd4j.linspace(1, 2, 2, DataType.DOUBLE).reshape(2, 1);
INDArray same = sum.dup(); INDArray same = sum.dup();
assertEquals(same.sum(1), sum.reshape(2)); assertEquals(same.sum(1), sum.reshape(2));
@ -697,7 +768,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testEps() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEps(Nd4jBackend backend) {
val ones = Nd4j.ones(5); val ones = Nd4j.ones(5);
val res = Nd4j.createUninitialized(DataType.BOOL, 5); val res = Nd4j.createUninitialized(DataType.BOOL, 5);
assertTrue(Nd4j.getExecutioner().exec(new Eps(ones, ones, res)).all()); assertTrue(Nd4j.getExecutioner().exec(new Eps(ones, ones, res)).all());
@ -705,7 +778,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testLogDouble() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogDouble(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).castTo(DataType.DOUBLE); INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).castTo(DataType.DOUBLE);
INDArray log = Transforms.log(linspace); INDArray log = Transforms.log(linspace);
INDArray assertion = Nd4j.create(new double[] {0, 0.6931471805599453, 1.0986122886681098, 1.3862943611198906, 1.6094379124341005, 1.791759469228055}); INDArray assertion = Nd4j.create(new double[] {0, 0.6931471805599453, 1.0986122886681098, 1.3862943611198906, 1.6094379124341005, 1.791759469228055});
@ -713,28 +788,36 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testVectorSum() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorSum(Nd4jBackend backend) {
INDArray lin = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray lin = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1); assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
} }
@Test @Test
public void testVectorSum2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorSum2(Nd4jBackend backend) {
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1); assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
} }
@Test @Test
public void testVectorSum3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorSum3(Nd4jBackend backend) {
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray lin2 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray lin2 = Nd4j.create(new double[] {1, 2, 3, 4});
assertEquals(lin, lin2); assertEquals(lin, lin2);
} }
@Test @Test
public void testSmallSum() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSmallSum(Nd4jBackend backend) {
INDArray base = Nd4j.create(new double[] {5.843333333333335, 3.0540000000000007}); INDArray base = Nd4j.create(new double[] {5.843333333333335, 3.0540000000000007});
base.addi(1e-12); base.addi(1e-12);
INDArray assertion = Nd4j.create(new double[] {5.84333433, 3.054001}); INDArray assertion = Nd4j.create(new double[] {5.84333433, 3.054001});
@ -745,7 +828,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testPermute() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4}); INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4});
INDArray transpose = n.transpose(); INDArray transpose = n.transpose();
INDArray permute = n.permute(1, 0); INDArray permute = n.permute(1, 0);
@ -774,7 +859,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testAppendBias() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAppendBias(Nd4jBackend backend) {
INDArray rand = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(1, -1).transpose(); INDArray rand = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(1, -1).transpose();
INDArray test = Nd4j.appendBias(rand); INDArray test = Nd4j.appendBias(rand);
INDArray assertion = Nd4j.toFlattened(rand, Nd4j.scalar(DataType.DOUBLE, 1.0)).reshape(-1, 1); INDArray assertion = Nd4j.toFlattened(rand, Nd4j.scalar(DataType.DOUBLE, 1.0)).reshape(-1, 1);
@ -782,7 +869,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testRand() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRand(Nd4jBackend backend) {
INDArray rand = Nd4j.randn(5, 5); INDArray rand = Nd4j.randn(5, 5);
Nd4j.getDistributions().createUniform(0.4, 4).sample(5); Nd4j.getDistributions().createUniform(0.4, 4).sample(5);
Nd4j.getDistributions().createNormal(1, 5).sample(10); Nd4j.getDistributions().createNormal(1, 5).sample(10);
@ -794,7 +883,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testIdentity() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIdentity(Nd4jBackend backend) {
INDArray eye = Nd4j.eye(5); INDArray eye = Nd4j.eye(5);
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape())); assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
eye = Nd4j.eye(5); eye = Nd4j.eye(5);
@ -805,7 +896,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testColumnVectorOpsFortran() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testColumnVectorOpsFortran(Nd4jBackend backend) {
INDArray twoByTwo = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray twoByTwo = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray toAdd = Nd4j.create(new float[] {1, 2}, new long[] {2, 1}); INDArray toAdd = Nd4j.create(new float[] {1, 2}, new long[] {2, 1});
twoByTwo.addiColumnVector(toAdd); twoByTwo.addiColumnVector(toAdd);
@ -816,7 +909,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testRSubi() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRSubi(Nd4jBackend backend) {
INDArray n2 = Nd4j.ones(2); INDArray n2 = Nd4j.ones(2);
INDArray n2Assertion = Nd4j.zeros(2); INDArray n2Assertion = Nd4j.zeros(2);
INDArray nRsubi = n2.rsubi(1); INDArray nRsubi = n2.rsubi(1);
@ -826,7 +921,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testAssign() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssign(Nd4jBackend backend) {
INDArray vector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray vector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
vector.assign(1); vector.assign(1);
assertEquals(Nd4j.ones(5).castTo(DataType.DOUBLE), vector); assertEquals(Nd4j.ones(5).castTo(DataType.DOUBLE), vector);
@ -843,7 +940,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testAddScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddScalar(Nd4jBackend backend) {
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0); INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
INDArray rdiv = div.add(1); INDArray rdiv = div.add(1);
INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 5.0); INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 5.0);
@ -851,7 +950,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testRdivScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRdivScalar(Nd4jBackend backend) {
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0); INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
INDArray rdiv = div.rdiv(1); INDArray rdiv = div.rdiv(1);
INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 0.25); INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 0.25);
@ -859,7 +960,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testRDivi() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRDivi(Nd4jBackend backend) {
INDArray n2 = Nd4j.valueArrayOf(new long[] {1, 2}, 4.0); INDArray n2 = Nd4j.valueArrayOf(new long[] {1, 2}, 4.0);
INDArray n2Assertion = Nd4j.valueArrayOf(new long[] {1, 2}, 0.5); INDArray n2Assertion = Nd4j.valueArrayOf(new long[] {1, 2}, 0.5);
INDArray nRsubi = n2.rdivi(2); INDArray nRsubi = n2.rdivi(2);
@ -869,7 +972,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testNumVectorsAlongDimension() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNumVectorsAlongDimension(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2);
assertEquals(12, arr.vectorsAlongDimension(2)); assertEquals(12, arr.vectorsAlongDimension(2));
} }
@ -877,7 +982,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testBroadCast() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadCast(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
INDArray broadCasted = n.broadcast(5, 4); INDArray broadCasted = n.broadcast(5, 4);
for (int i = 0; i < broadCasted.rows(); i++) { for (int i = 0; i < broadCasted.rows(); i++) {
@ -899,7 +1006,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testMatrix() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrix(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray brr = Nd4j.create(new double[] {5, 6}, new long[] {2}); INDArray brr = Nd4j.create(new double[] {5, 6}, new long[] {2});
INDArray row = arr.getRow(0); INDArray row = arr.getRow(0);
@ -909,7 +1018,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testPutRowGetRowOrdering() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRowGetRowOrdering(Nd4jBackend backend) {
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray put = Nd4j.create(new double[] {5, 6}); INDArray put = Nd4j.create(new double[] {5, 6});
row1.putRow(1, put); row1.putRow(1, put);
@ -931,7 +1042,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testSumWithRow1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSumWithRow1(Nd4jBackend backend) {
//Works: //Works:
INDArray array2d = Nd4j.ones(1, 10); INDArray array2d = Nd4j.ones(1, 10);
array2d.sum(0); //OK array2d.sum(0); //OK
@ -962,7 +1075,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testSumWithRow2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSumWithRow2(Nd4jBackend backend) {
//All sums in this method execute without exceptions. //All sums in this method execute without exceptions.
INDArray array3d = Nd4j.ones(2, 10, 10); INDArray array3d = Nd4j.ones(2, 10, 10);
array3d.sum(0); array3d.sum(0);
@ -985,7 +1100,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testPutRowFortran() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRowFortran(Nd4jBackend backend) {
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2).castTo(DataType.DOUBLE); INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2).castTo(DataType.DOUBLE);
INDArray put = Nd4j.create(new double[] {5, 6}); INDArray put = Nd4j.create(new double[] {5, 6});
row1.putRow(1, put); row1.putRow(1, put);
@ -998,7 +1115,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testElementWiseOps() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testElementWiseOps(Nd4jBackend backend) {
INDArray n1 = Nd4j.scalar(1); INDArray n1 = Nd4j.scalar(1);
INDArray n2 = Nd4j.scalar(2); INDArray n2 = Nd4j.scalar(2);
INDArray nClone = n1.add(n2); INDArray nClone = n1.add(n2);
@ -1021,7 +1140,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
public void testRollAxis() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRollAxis(Nd4jBackend backend) {
INDArray toRoll = Nd4j.ones(3, 4, 5, 6); INDArray toRoll = Nd4j.ones(3, 4, 5, 6);
assertArrayEquals(new long[] {3, 6, 4, 5}, Nd4j.rollAxis(toRoll, 3, 1).shape()); assertArrayEquals(new long[] {3, 6, 4, 5}, Nd4j.rollAxis(toRoll, 3, 1).shape());
val shape = Nd4j.rollAxis(toRoll, 3).shape(); val shape = Nd4j.rollAxis(toRoll, 3).shape();
@ -1030,20 +1151,22 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test @Test
@Disabled @Disabled
public void testTensorDot() { public void testTensorDot(Nd4jBackend backend) {
INDArray oneThroughSixty = Nd4j.arange(60).reshape('f', 3, 4, 5).castTo(DataType.DOUBLE); INDArray oneThroughSixty = Nd4j.arange(60).reshape('f', 3, 4, 5).castTo(DataType.DOUBLE);
INDArray oneThroughTwentyFour = Nd4j.arange(24).reshape('f', 4, 3, 2).castTo(DataType.DOUBLE); INDArray oneThroughTwentyFour = Nd4j.arange(24).reshape('f', 4, 3, 2).castTo(DataType.DOUBLE);
INDArray result = Nd4j.tensorMmul(oneThroughSixty, oneThroughTwentyFour, new int[][] {{1, 0}, {0, 1}}); INDArray result = Nd4j.tensorMmul(oneThroughSixty, oneThroughTwentyFour, new int[][] {{1, 0}, {0, 1}});
assertArrayEquals(new long[] {5, 2}, result.shape()); assertArrayEquals(new long[] {5, 2}, result.shape());
INDArray assertion = Nd4j.create(new double[][] {{440., 1232.}, {1232., 3752.}, {2024., 6272.}, {2816., 8792.}, INDArray assertion = Nd4j.create(new double[][] {{440., 1232.}, {1232., 3752.}, {2024., 6272.}, {2816., 8792.},
{3608., 11312.}}); {3608., 11312.}});
assertEquals(assertion, result); assertEquals(assertion, result);
} }
@Test @Test
public void testNegativeShape() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNegativeShape(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
INDArray reshaped = linspace.reshape(-1, 2); INDArray reshaped = linspace.reshape(-1, 2);
assertArrayEquals(new long[] {2, 2}, reshaped.shape()); assertArrayEquals(new long[] {2, 2}, reshaped.shape());
@ -1055,7 +1178,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testGetColumnGetRow() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumnGetRow(Nd4jBackend backend) {
INDArray row = Nd4j.ones(1, 5); INDArray row = Nd4j.ones(1, 5);
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
INDArray col = row.getColumn(i); INDArray col = row.getColumn(i);
@ -1070,7 +1195,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testDupAndDupWithOrder() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDupAndDupWithOrder(Nd4jBackend backend) {
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE); List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
int count = 0; int count = 0;
for (Pair<INDArray, String> pair : testInputs) { for (Pair<INDArray, String> pair : testInputs) {
@ -1092,7 +1219,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
} }
@Test @Test
public void testToOffsetZeroCopy() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToOffsetZeroCopy(Nd4jBackend backend) {
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE); List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
int cnt = 0; int cnt = 0;

View File

@ -23,8 +23,9 @@ package org.nd4j.linalg;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -42,18 +43,14 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class Nd4jTestsComparisonC extends BaseNd4jTest { public class Nd4jTestsComparisonC extends BaseNd4jTestWithBackends {
private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonC.class); private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonC.class);
public static final int SEED = 123; public static final int SEED = 123;
DataType initialType; DataType initialType = Nd4j.dataType();
public Nd4jTestsComparisonC(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
@ -73,7 +70,9 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest {
@Test @Test
public void testGemmWithOpsCommonsMath() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
@ -140,13 +139,13 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest {
private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair<INDArray, String> first, private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair<INDArray, String> first,
Pair<INDArray, String> second) { Pair<INDArray, String> second) {
return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")"; return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")";
} }
private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha, private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha,
double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) { double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
return i + "," + j + " - gemm(tA=" + transposeA + ",tB=" + transposeB + ",alpha=" + alpha + ",beta=" + beta return i + "," + j + " - gemm(tA=" + transposeA + ",tB=" + transposeB + ",alpha=" + alpha + ",beta=" + beta
+ "). A=" + first.getSecond() + ", B=" + second.getSecond(); + "). A=" + first.getSecond() + ", B=" + second.getSecond();
} }
} }

View File

@ -25,8 +25,9 @@ import org.apache.commons.math3.linear.RealMatrix;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -43,18 +44,14 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@RunWith(Parameterized.class)
public class Nd4jTestsComparisonFortran extends BaseNd4jTest { public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonFortran.class); private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonFortran.class);
public static final int SEED = 123; public static final int SEED = 123;
DataType initialType; DataType initialType = Nd4j.dataType();
public Nd4jTestsComparisonFortran(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
@ -75,7 +72,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
} }
@Test @Test
public void testCrash() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCrash(Nd4jBackend backend) {
INDArray array3d = Nd4j.ones(1, 10, 10); INDArray array3d = Nd4j.ones(1, 10, 10);
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 0); Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 0);
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 1); Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 1);
@ -85,7 +84,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
} }
@Test @Test
public void testMmulWithOpsCommonsMath() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
@ -100,7 +101,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
} }
@Test @Test
public void testGemmWithOpsCommonsMath() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
@ -156,7 +159,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
} }
@Test @Test
public void testGemvApacheCommons() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemvApacheCommons(Nd4jBackend backend) {
int[] rowsArr = new int[] {4, 4, 4, 8, 8, 8}; int[] rowsArr = new int[] {4, 4, 4, 8, 8, 8};
int[] colsArr = new int[] {2, 1, 10, 2, 1, 10}; int[] colsArr = new int[] {2, 1, 10, 2, 1, 10};
@ -197,7 +202,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
assertArrayEquals(new long[] {rows, 1}, gemv.shape()); assertArrayEquals(new long[] {rows, 1}, gemv.shape());
assertArrayEquals(new int[] {rows, 1}, assertArrayEquals(new int[] {rows, 1},
new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()}); new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()});
//Check entries: //Check entries:
for (int r = 0; r < rows; r++) { for (int r = 0; r < rows; r++) {
@ -211,7 +216,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
} }
@Test @Test
public void testAddSubtractWithOpsCommonsMath() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddSubtractWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
for (int i = 0; i < first.size(); i++) { for (int i = 0; i < first.size(); i++) {
@ -229,7 +236,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
} }
@Test @Test
public void testMulDivOnCheckUtilMatrices() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMulDivOnCheckUtilMatrices(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
for (int i = 0; i < first.size(); i++) { for (int i = 0; i < first.size(); i++) {
@ -245,13 +254,13 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
} }
private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair<INDArray, String> first, private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair<INDArray, String> first,
Pair<INDArray, String> second) { Pair<INDArray, String> second) {
return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")"; return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")";
} }
private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha, private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha,
double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) { double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
return i + "," + j + " - gemm(tA=" + transposeA + ",tB= " + transposeB + ",alpha=" + alpha + ",beta= " + beta return i + "," + j + " - gemm(tA=" + transposeA + ",tB= " + transposeB + ",alpha=" + alpha + ",beta= " + beta
+ "). A=" + first.getSecond() + ", B=" + second.getSecond(); + "). A=" + first.getSecond() + ", B=" + second.getSecond();
} }
} }

View File

@ -23,8 +23,9 @@ package org.nd4j.linalg;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -36,18 +37,15 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class Nd4jTestsF extends BaseNd4jTest {
DataType initialType; public class Nd4jTestsF extends BaseNd4jTestWithBackends {
public Nd4jTestsF(Nd4jBackend backend) { DataType initialType = Nd4j.dataType();
super(backend);
this.initialType = Nd4j.dataType();
}
@Test @Test
public void testConcat3D_Vstack_F() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat3D_Vstack_F(Nd4jBackend backend) {
//Nd4j.getExecutioner().enableVerboseMode(true); //Nd4j.getExecutioner().enableVerboseMode(true);
//Nd4j.getExecutioner().enableDebugMode(true); //Nd4j.getExecutioner().enableDebugMode(true);
@ -79,7 +77,9 @@ public class Nd4jTestsF extends BaseNd4jTest {
@Test @Test
public void testSlice_1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlice_1(Nd4jBackend backend) {
val arr = Nd4j.linspace(1,4, 4, DataType.DOUBLE).reshape(2, 2, 1); val arr = Nd4j.linspace(1,4, 4, DataType.DOUBLE).reshape(2, 2, 1);
val exp0 = Nd4j.create(new double[]{1, 3}, new int[] {2, 1}); val exp0 = Nd4j.create(new double[]{1, 3}, new int[] {2, 1});
val exp1 = Nd4j.create(new double[]{2, 4}, new int[] {2, 1}); val exp1 = Nd4j.create(new double[]{2, 4}, new int[] {2, 1});

View File

@ -22,8 +22,9 @@ package org.nd4j.linalg;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -34,15 +35,13 @@ import java.util.*;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@RunWith(Parameterized.class)
public class ShufflesTests extends BaseNd4jTest {
public ShufflesTests(Nd4jBackend backend) { public class ShufflesTests extends BaseNd4jTestWithBackends {
super(backend);
}
@Test @Test
public void testSimpleShuffle1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimpleShuffle1(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(10, 10); INDArray array = Nd4j.zeros(10, 10);
for (int x = 0; x < 10; x++) { for (int x = 0; x < 10; x++) {
array.getRow(x).assign(x); array.getRow(x).assign(x);
@ -64,7 +63,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testSimpleShuffle2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimpleShuffle2(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(10, 10); INDArray array = Nd4j.zeros(10, 10);
for (int x = 0; x < 10; x++) { for (int x = 0; x < 10; x++) {
array.getColumn(x).assign(x); array.getColumn(x).assign(x);
@ -79,7 +80,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testSimpleShuffle3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimpleShuffle3(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(11, 10); INDArray array = Nd4j.zeros(11, 10);
for (int x = 0; x < 11; x++) { for (int x = 0; x < 11; x++) {
array.getRow(x).assign(x); array.getRow(x).assign(x);
@ -95,7 +98,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testSymmetricShuffle1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSymmetricShuffle1(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10); INDArray features = Nd4j.zeros(10, 10);
INDArray labels = Nd4j.zeros(10, 3); INDArray labels = Nd4j.zeros(10, 3);
for (int x = 0; x < 10; x++) { for (int x = 0; x < 10; x++) {
@ -133,7 +138,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testSymmetricShuffle2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSymmetricShuffle2(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10, 20); INDArray features = Nd4j.zeros(10, 10, 20);
INDArray labels = Nd4j.zeros(10, 10, 3); INDArray labels = Nd4j.zeros(10, 10, 3);
@ -171,7 +178,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testSymmetricShuffle3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSymmetricShuffle3(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10, 20); INDArray features = Nd4j.zeros(10, 10, 20);
INDArray featuresMask = Nd4j.zeros(10, 20); INDArray featuresMask = Nd4j.zeros(10, 20);
INDArray labels = Nd4j.zeros(10, 10, 3); INDArray labels = Nd4j.zeros(10, 10, 3);
@ -236,7 +245,9 @@ public class ShufflesTests extends BaseNd4jTest {
* @throws Exception * @throws Exception
*/ */
@Test @Test
public void testHalfVectors1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testHalfVectors1(Nd4jBackend backend) {
int[] array1 = ArrayUtil.buildHalfVector(new Random(12), 20); int[] array1 = ArrayUtil.buildHalfVector(new Random(12), 20);
int[] array2 = ArrayUtil.buildHalfVector(new Random(75), 20); int[] array2 = ArrayUtil.buildHalfVector(new Random(75), 20);
@ -257,7 +268,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testInterleavedVector1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInterleavedVector1(Nd4jBackend backend) {
int[] array1 = ArrayUtil.buildInterleavedVector(new Random(12), 20); int[] array1 = ArrayUtil.buildInterleavedVector(new Random(12), 20);
int[] array2 = ArrayUtil.buildInterleavedVector(new Random(75), 20); int[] array2 = ArrayUtil.buildInterleavedVector(new Random(75), 20);
@ -278,7 +291,9 @@ public class ShufflesTests extends BaseNd4jTest {
} }
@Test @Test
public void testInterleavedVector3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInterleavedVector3(Nd4jBackend backend) {
for (int e = 0; e < 1000; e++) { for (int e = 0; e < 1000; e++) {
int length = e + 256; //RandomUtils.nextInt(121, 2073); int length = e + 256; //RandomUtils.nextInt(121, 2073);
int[] array1 = ArrayUtil.buildInterleavedVector(new Random(System.currentTimeMillis()), length); int[] array1 = ArrayUtil.buildInterleavedVector(new Random(System.currentTimeMillis()), length);

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.eigen.Eigen; import org.nd4j.linalg.eigen.Eigen;
@ -35,16 +36,11 @@ import org.nd4j.common.util.ArrayUtil;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
@Slf4j @Slf4j
public class TestEigen extends BaseNd4jTest { public class TestEigen extends BaseNd4jTestWithBackends {
protected DataType initialType; protected DataType initialType = Nd4j.dataType();
public TestEigen(Nd4jBackend backend) {
super(backend);
initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
public void before() { public void before() {
@ -59,7 +55,9 @@ public class TestEigen extends BaseNd4jTest {
// test of functions added by Luke Czapla // test of functions added by Luke Czapla
// Compares solution of A x = L x to solution to A x = L B x when it is simple // Compares solution of A x = L x to solution to A x = L B x when it is simple
@Test @Test
public void test2Syev() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void test2Syev(Nd4jBackend backend) {
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
Nd4j.setDefaultDataTypes(dt, dt); Nd4j.setDefaultDataTypes(dt, dt);
@ -78,7 +76,9 @@ public class TestEigen extends BaseNd4jTest {
} }
@Test @Test
public void testSyev() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSyev(Nd4jBackend backend) {
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
//log.info("Datatype: {}", dt); //log.info("Datatype: {}", dt);
Nd4j.setDefaultDataTypes(dt, dt); Nd4j.setDefaultDataTypes(dt, dt);

View File

@ -24,23 +24,23 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.util.ArrayUtil; import org.nd4j.common.util.ArrayUtil;
@RunWith(Parameterized.class)
@Slf4j @Slf4j
public class ToStringTest extends BaseNd4jTest { public class ToStringTest extends BaseNd4jTestWithBackends {
public ToStringTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testToString() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToString(Nd4jBackend backend) throws Exception {
assertEquals("[ 1, 2, 3]", assertEquals("[ 1, 2, 3]",
Nd4j.createFromArray(1, 2, 3).toString()); Nd4j.createFromArray(1, 2, 3).toString());
@ -58,6 +58,8 @@ public class ToStringTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToStringScalars(){ public void testToStringScalars(){
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32}; DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"}; String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.activations;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.impl.ActivationCube; import org.nd4j.linalg.activations.impl.ActivationCube;
import org.nd4j.linalg.activations.impl.ActivationELU; import org.nd4j.linalg.activations.impl.ActivationELU;
import org.nd4j.linalg.activations.impl.ActivationGELU; import org.nd4j.linalg.activations.impl.ActivationGELU;
@ -55,12 +56,9 @@ import java.util.List;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class TestActivation extends BaseNd4jTest {
public TestActivation(Nd4jBackend backend) { public class TestActivation extends BaseNd4jTestWithBackends {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
@ -79,7 +77,9 @@ public class TestActivation extends BaseNd4jTest {
} }
@Test @Test
public void testRelu(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRelu(Nd4jBackend backend){
Double[] max = {null, 6.0, 2.5, 5.0}; Double[] max = {null, 6.0, 2.5, 5.0};
Double[] threshold = {0.0, 0.0, 0.75, 0.2}; Double[] threshold = {0.0, 0.0, 0.75, 0.2};
@ -131,30 +131,32 @@ public class TestActivation extends BaseNd4jTest {
} }
@Test @Test
public void testJson() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testJson(Nd4jBackend backend) throws Exception {
IActivation[] activations = new IActivation[] {new ActivationCube(), new ActivationELU(0.25), IActivation[] activations = new IActivation[] {new ActivationCube(), new ActivationELU(0.25),
new ActivationHardSigmoid(), new ActivationHardTanH(), new ActivationIdentity(), new ActivationHardSigmoid(), new ActivationHardTanH(), new ActivationIdentity(),
new ActivationLReLU(0.25), new ActivationRationalTanh(), new ActivationReLU(), new ActivationLReLU(0.25), new ActivationRationalTanh(), new ActivationReLU(),
new ActivationRReLU(0.25, 0.5), new ActivationSigmoid(), new ActivationSoftmax(), new ActivationRReLU(0.25, 0.5), new ActivationSigmoid(), new ActivationSoftmax(),
new ActivationSoftPlus(), new ActivationSoftSign(), new ActivationTanH(), new ActivationGELU(), new ActivationGELU(true)}; new ActivationSoftPlus(), new ActivationSoftSign(), new ActivationTanH(), new ActivationGELU(), new ActivationGELU(true)};
String[][] expectedFields = new String[][] {{"@class"}, //Cube String[][] expectedFields = new String[][] {{"@class"}, //Cube
{"@class", "alpha"}, //ELU {"@class", "alpha"}, //ELU
{"@class"}, //Hard sigmoid {"@class"}, //Hard sigmoid
{"@class"}, //Hard TanH {"@class"}, //Hard TanH
{"@class"}, //Identity {"@class"}, //Identity
{"@class", "alpha"}, //Leaky Relu {"@class", "alpha"}, //Leaky Relu
{"@class"}, //rational tanh {"@class"}, //rational tanh
{"@class", "max", "negativeSlope", "threshold"}, //relu {"@class", "max", "negativeSlope", "threshold"}, //relu
{"@class", "l", "u"}, //rrelu {"@class", "l", "u"}, //rrelu
{"@class"}, //sigmoid {"@class"}, //sigmoid
{"@class"}, //Softmax {"@class"}, //Softmax
{"@class"}, //Softplus {"@class"}, //Softplus
{"@class"}, //Softsign {"@class"}, //Softsign
{"@class"}, //Tanh {"@class"}, //Tanh
{"@class", "precise"}, //GELU {"@class", "precise"}, //GELU
{"@class", "precise"} //GELU precise {"@class", "precise"} //GELU precise
}; };
@ -172,7 +174,7 @@ public class TestActivation extends BaseNd4jTest {
String[] expFields = expectedFields[i]; String[] expFields = expectedFields[i];
String msg = activations[i].toString() + "\tExpected fields: " + Arrays.toString(expFields) String msg = activations[i].toString() + "\tExpected fields: " + Arrays.toString(expFields)
+ "\tActual fields: " + actualFieldsByName; + "\tActual fields: " + actualFieldsByName;
assertEquals(expFields.length, actualFieldsByName.size(),msg); assertEquals(expFields.length, actualFieldsByName.size(),msg);
for (String s : expFields) { for (String s : expFields) {

View File

@ -20,21 +20,20 @@
package org.nd4j.linalg.api; package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.nd4j.linalg.factory.Environment; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
public class TestBackend extends BaseNd4jTest { public class TestBackend extends BaseNd4jTestWithBackends {
public TestBackend(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void TestBuildInfo(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBuildInfo(Nd4jBackend backend){
System.out.println("Backend build info: " + backend.buildInfo()); System.out.println("Backend build info: " + backend.buildInfo());
} }
} }

View File

@ -20,26 +20,27 @@
package org.nd4j.linalg.api; package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Environment; import org.nd4j.linalg.factory.Environment;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
public class TestEnvironment extends BaseNd4jTest { public class TestEnvironment extends BaseNd4jTestWithBackends {
public TestEnvironment(Nd4jBackend backend) {
super(backend);
}
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';
} }
@Test @Test
public void testEnvironment(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEnvironment(Nd4jBackend backend){
Environment e = Nd4j.getEnvironment(); Environment e = Nd4j.getEnvironment();
System.out.println("BLAS version: " + e.blasMajorVersion() + "." + e.blasMinorVersion() + "." + e.blasPatchVersion()); System.out.println("BLAS version: " + e.blasMajorVersion() + "." + e.blasMinorVersion() + "." + e.blasPatchVersion());
System.out.println("CPU: " + e.isCPU()); System.out.println("CPU: " + e.isCPU());

View File

@ -26,7 +26,9 @@ import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -40,16 +42,12 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class TestNDArrayCreation extends BaseNd4jTest { public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
public TestNDArrayCreation(Nd4jBackend backend) {
super(backend);
}
@Test @Test
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") @ParameterizedTest
public void testBufferCreation() { @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBufferCreation(Nd4jBackend backend) {
DataBuffer dataBuffer = Nd4j.createBuffer(new float[] {1, 2}); DataBuffer dataBuffer = Nd4j.createBuffer(new float[] {1, 2});
Pointer pointer = dataBuffer.pointer(); Pointer pointer = dataBuffer.pointer();
FloatPointer floatPointer = new FloatPointer(pointer); FloatPointer floatPointer = new FloatPointer(pointer);
@ -69,6 +67,8 @@ public class TestNDArrayCreation extends BaseNd4jTest {
@Test @Test
@Disabled @Disabled
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCreateNpy() throws Exception { public void testCreateNpy() throws Exception {
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile()); INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile());
assertEquals(2, arrCreate.size(0)); assertEquals(2, arrCreate.size(0));
@ -82,7 +82,9 @@ public class TestNDArrayCreation extends BaseNd4jTest {
@Test @Test
@Disabled @Disabled
public void testCreateNpz() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCreateNpz(Nd4jBackend backend) throws Exception {
Map<String, INDArray> map = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile()); Map<String, INDArray> map = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile());
assertEquals(true, map.containsKey("x")); assertEquals(true, map.containsKey("x"));
assertEquals(true, map.containsKey("y")); assertEquals(true, map.containsKey("y"));
@ -100,8 +102,7 @@ public class TestNDArrayCreation extends BaseNd4jTest {
} }
@Test @Test
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") public void testCreateNpy3(Nd4jBackend backend) throws Exception {
public void testCreateNpy3() throws Exception {
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile()); INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
assertEquals(8, arrCreate.length()); assertEquals(8, arrCreate.length());
assertEquals(3, arrCreate.rank()); assertEquals(3, arrCreate.rank());
@ -113,7 +114,7 @@ public class TestNDArrayCreation extends BaseNd4jTest {
@Test @Test
@Disabled // this is endless test @Disabled // this is endless test
public void testEndlessAllocation() { public void testEndlessAllocation(Nd4jBackend backend) {
Nd4j.getEnvironment().setMaxSpecialMemory(1); Nd4j.getEnvironment().setMaxSpecialMemory(1);
while (true) { while (true) {
val arr = Nd4j.createUninitialized(DataType.FLOAT, 100000000); val arr = Nd4j.createUninitialized(DataType.FLOAT, 100000000);

View File

@ -21,24 +21,23 @@
package org.nd4j.linalg.api; package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil; import org.nd4j.common.util.ArrayUtil;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
public class TestNDArrayCreationUtil extends BaseNd4jTest { public class TestNDArrayCreationUtil extends BaseNd4jTestWithBackends {
public TestNDArrayCreationUtil(Nd4jBackend backend) {
super(backend);
}
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShapes() { public void testShapes() {
long[] shape2d = {2, 3}; long[] shape2d = {2, 3};

View File

@ -21,20 +21,21 @@
package org.nd4j.linalg.api; package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
public class TestNamespaces extends BaseNd4jTest { public class TestNamespaces extends BaseNd4jTestWithBackends {
public TestNamespaces(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testBitwiseSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBitwiseSimple(Nd4jBackend backend){
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT); INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
INDArray y = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT); INDArray y = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
@ -50,7 +51,9 @@ public class TestNamespaces extends BaseNd4jTest {
} }
@Test @Test
public void testMathSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMathSimple(Nd4jBackend backend) {
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1); INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1);
INDArray abs = Nd4j.math.abs(x); INDArray abs = Nd4j.math.abs(x);
// System.out.println(x); // System.out.println(x);
@ -65,7 +68,9 @@ public class TestNamespaces extends BaseNd4jTest {
} }
@Test @Test
public void testRandomSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomSimple(Nd4jBackend backend){
INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10); INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10);
// System.out.println(normal); // System.out.println(normal);
INDArray uniform = Nd4j.random.uniform(0, 1, DataType.FLOAT, 10); INDArray uniform = Nd4j.random.uniform(0, 1, DataType.FLOAT, 10);
@ -73,7 +78,9 @@ public class TestNamespaces extends BaseNd4jTest {
} }
@Test @Test
public void testNeuralNetworkSimple(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNeuralNetworkSimple(Nd4jBackend backend){
INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10)); INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10));
// System.out.println(out); // System.out.println(out);
INDArray out2 = Nd4j.nn.softmax(Nd4j.random.normal(0, 1, DataType.FLOAT, 4, 5), 1); INDArray out2 = Nd4j.nn.softmax(Nd4j.random.normal(0, 1, DataType.FLOAT, 4, 5), 1);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.blas; package org.nd4j.linalg.api.blas;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -31,15 +32,14 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class LapackTest extends BaseNd4jTest { public class LapackTest extends BaseNd4jTestWithBackends {
public LapackTest(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testQRSquare() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testQRSquare(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9}); INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9});
A = A.reshape('c', 3, 3); A = A.reshape('c', 3, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape()); INDArray O = Nd4j.create(A.dataType(), A.shape());
@ -57,7 +57,9 @@ public class LapackTest extends BaseNd4jTest {
} }
@Test @Test
public void testQRRect() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testQRRect(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
A = A.reshape('f', 4, 3); A = A.reshape('f', 4, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape()); INDArray O = Nd4j.create(A.dataType(), A.shape());
@ -75,7 +77,9 @@ public class LapackTest extends BaseNd4jTest {
} }
@Test @Test
public void testCholeskyL() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCholeskyL(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {2, -1, 1, -1, 2, -1, 1, -1, 2,}); INDArray A = Nd4j.create(new double[] {2, -1, 1, -1, 2, -1, 1, -1, 2,});
A = A.reshape('c', 3, 3); A = A.reshape('c', 3, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape()); INDArray O = Nd4j.create(A.dataType(), A.shape());
@ -92,7 +96,9 @@ public class LapackTest extends BaseNd4jTest {
} }
@Test @Test
public void testCholeskyU() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCholeskyU(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,}); INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,});
A = A.reshape('f', 3, 3); A = A.reshape('f', 3, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape()); INDArray O = Nd4j.create(A.dataType(), A.shape());

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.api.blas;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -35,14 +36,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class Level1Test extends BaseNd4jTest { public class Level1Test extends BaseNd4jTestWithBackends {
public Level1Test(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testDot() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDot(Nd4jBackend backend) {
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4});
assertEquals(30, Nd4j.getBlasWrapper().dot(vec1, vec2), 1e-1); assertEquals(30, Nd4j.getBlasWrapper().dot(vec1, vec2), 1e-1);
@ -55,7 +55,9 @@ public class Level1Test extends BaseNd4jTest {
} }
@Test @Test
public void testAxpy() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAxpy(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray row = matrix.getRow(1); INDArray row = matrix.getRow(1);
Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row); Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row);
@ -64,7 +66,9 @@ public class Level1Test extends BaseNd4jTest {
} }
@Test @Test
public void testAxpy2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAxpy2(Nd4jBackend backend) {
val rowX = Nd4j.create(new double[]{1, 2, 3, 4}); val rowX = Nd4j.create(new double[]{1, 2, 3, 4});
val rowY = Nd4j.create(new double[]{1, 2, 3, 4}); val rowY = Nd4j.create(new double[]{1, 2, 3, 4});
val exp = Nd4j.create(new double[]{3, 6, 9, 12}); val exp = Nd4j.create(new double[]{3, 6, 9, 12});

View File

@ -21,23 +21,23 @@
package org.nd4j.linalg.api.blas; package org.nd4j.linalg.api.blas;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class Level2Test extends BaseNd4jTest { public class Level2Test extends BaseNd4jTestWithBackends {
public Level2Test(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testGemv1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv1(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -51,7 +51,9 @@ public class Level2Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemv2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv2(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -65,7 +67,9 @@ public class Level2Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemv3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv3(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -79,7 +83,9 @@ public class Level2Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemv4() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv4(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -93,7 +99,9 @@ public class Level2Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemv5() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv5(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -109,7 +117,9 @@ public class Level2Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemv6() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv6(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -125,7 +135,9 @@ public class Level2Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemv7() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv7(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);

View File

@ -21,23 +21,23 @@
package org.nd4j.linalg.api.blas; package org.nd4j.linalg.api.blas;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class Level3Test extends BaseNd4jTest { public class Level3Test extends BaseNd4jTestWithBackends {
public Level3Test(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testGemm1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm1(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape(1, 100); INDArray array1 = Nd4j.linspace(1, 100, 100).reshape(1, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -47,7 +47,9 @@ public class Level3Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemm2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm2(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape('f', 1, 100); INDArray array1 = Nd4j.linspace(1, 100, 100).reshape('f', 1, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -57,7 +59,9 @@ public class Level3Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemm3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm3(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
@ -75,7 +79,9 @@ public class Level3Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemm4() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm4(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);
@ -92,7 +98,9 @@ public class Level3Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemm5() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm5(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
@ -106,7 +114,9 @@ public class Level3Test extends BaseNd4jTest {
} }
@Test @Test
public void testGemm6() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm6(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.blas.params; package org.nd4j.linalg.api.blas.params;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -33,16 +34,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class ParamsTestsF extends BaseNd4jTest {
public class ParamsTestsF extends BaseNd4jTestWithBackends {
public ParamsTestsF(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testGemm() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm (Nd4jBackend backend) {
INDArray a = Nd4j.create(2, 2); INDArray a = Nd4j.create(2, 2);
INDArray b = Nd4j.create(2, 3); INDArray b = Nd4j.create(2, 3);
INDArray c = Nd4j.create(2, 3); INDArray c = Nd4j.create(2, 3);

View File

@ -25,9 +25,10 @@ import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*; import org.bytedeco.javacpp.indexer.*;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
@ -45,16 +46,15 @@ import java.nio.ByteOrder;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class DataBufferTests extends BaseNd4jTest {
public DataBufferTests(Nd4jBackend backend) { public class DataBufferTests extends BaseNd4jTestWithBackends {
super(backend);
}
@Test @Test
@Disabled("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657") @Disabled("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657")
public void testNoArgCreateBufferFromArray() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNoArgCreateBufferFromArray(Nd4jBackend backend) {
//Tests here: //Tests here:
//1. Create from JVM array //1. Create from JVM array
@ -280,7 +280,9 @@ public class DataBufferTests extends BaseNd4jTest {
@Test @Test
public void testCreateTypedBuffer() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCreateTypedBuffer(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
@ -350,7 +352,9 @@ public class DataBufferTests extends BaseNd4jTest {
} }
@Test @Test
public void testAsBytes() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAsBytes(Nd4jBackend backend) {
INDArray orig = Nd4j.linspace(DataType.INT, 0, 10, 1); INDArray orig = Nd4j.linspace(DataType.INT, 0, 10, 1);
for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16, for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16,
@ -404,7 +408,9 @@ public class DataBufferTests extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEnsureLocation(){ public void testEnsureLocation(){
//https://github.com/eclipse/deeplearning4j/issues/8783 //https://github.com/eclipse/deeplearning4j/issues/8783
Nd4j.create(1); Nd4j.create(1);

View File

@ -23,9 +23,10 @@ package org.nd4j.linalg.api.buffer;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -33,13 +34,10 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@RunWith(Parameterized.class)
public class DataTypeValidationTests extends BaseNd4jTest {
DataType initialType;
public DataTypeValidationTests(Nd4jBackend backend) { public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
super(backend); DataType initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
public void setUp() { public void setUp() {
@ -48,7 +46,7 @@ public class DataTypeValidationTests extends BaseNd4jTest {
} }
@AfterEach @AfterEach
public void shutUp() { public void reset() {
Nd4j.setDataType(initialType); Nd4j.setDataType(initialType);
} }
@ -73,7 +71,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
* Testing level1 blas * Testing level1 blas
*/ */
@Test() @Test()
public void testBlasValidation1() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBlasValidation1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> { assertThrows(ND4JIllegalStateException.class,() -> {
INDArray x = Nd4j.create(10); INDArray x = Nd4j.create(10);
@ -90,7 +90,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
* Testing level2 blas * Testing level2 blas
*/ */
@Test() @Test()
public void testBlasValidation2() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBlasValidation2(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> { assertThrows(RuntimeException.class,() -> {
INDArray a = Nd4j.create(100, 10); INDArray a = Nd4j.create(100, 10);
INDArray x = Nd4j.create(100); INDArray x = Nd4j.create(100);
@ -108,7 +110,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
* Testing level3 blas * Testing level3 blas
*/ */
@Test() @Test()
public void testBlasValidation3() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBlasValidation3(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
INDArray x = Nd4j.create(100, 100); INDArray x = Nd4j.create(100, 100);

View File

@ -26,9 +26,10 @@ import org.bytedeco.javacpp.indexer.Indexer;
import org.junit.jupiter.api.*; import org.junit.jupiter.api.*;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
@ -54,34 +55,31 @@ import static org.junit.jupiter.api.Assertions.*;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public class DoubleDataBufferTest extends BaseNd4jTest { public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
DataType initialType; DataType initialType = Nd4j.dataType();
public DoubleDataBufferTest(Nd4jBackend backend) {
super(backend);
initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
public void before() { public void before(Nd4jBackend backend) {
DataTypeUtil.setDTypeForContext(DataType.DOUBLE); DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
} }
@AfterEach @AfterEach
public void after() { public void after(Nd4jBackend backend) {
DataTypeUtil.setDTypeForContext(initialType); DataTypeUtil.setDTypeForContext(initialType);
} }
@Test @Test
public void testPointerCreation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPointerCreation(Nd4jBackend backend) {
DoublePointer floatPointer = new DoublePointer(1, 2, 3, 4); DoublePointer floatPointer = new DoublePointer(1, 2, 3, 4);
Indexer indexer = DoubleIndexer.create(floatPointer); Indexer indexer = DoubleIndexer.create(floatPointer);
DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.DOUBLE, 4, indexer); DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.DOUBLE, 4, indexer);
@ -89,8 +87,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
assertArrayEquals(other.asDouble(), buffer.asDouble(), 0.001); assertArrayEquals(other.asDouble(), buffer.asDouble(), 0.001);
} }
@Test @Test
public void testGetSet() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetSet(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4}; double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
double[] d2 = d.asDouble(); double[] d2 = d.asDouble();
@ -100,10 +100,12 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerialization2() throws Exception { public void testSerialization2() throws Exception {
INDArray[] arr = new INDArray[] {Nd4j.ones(1, 10), INDArray[] arr = new INDArray[] {Nd4j.ones(1, 10),
// Nd4j.ones(5,10).getRow(2) // Nd4j.ones(5,10).getRow(2)
}; };
for (INDArray a : arr) { for (INDArray a : arr) {
@ -128,7 +130,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerialization(@TempDir Path testDir) throws Exception { public void testSerialization(@TempDir Path testDir) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
DataBuffer buf = Nd4j.createBuffer(5); DataBuffer buf = Nd4j.createBuffer(5);
@ -150,8 +154,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testDup() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDup(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4}; double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
DataBuffer d2 = d.dup(); DataBuffer d2 = d.dup();
@ -160,8 +166,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testPut() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPut(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4}; double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
d.put(0, 0.0); d.put(0, 0.0);
@ -171,8 +179,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testGetRange() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
double[] get = buffer.getDoublesAt(0, 3); double[] get = buffer.getDoublesAt(0, 3);
double[] data = new double[] {1, 2, 3}; double[] data = new double[] {1, 2, 3};
@ -186,8 +196,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testGetOffsetRange() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetOffsetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
double[] get = buffer.getDoublesAt(1, 3); double[] get = buffer.getDoublesAt(1, 3);
double[] data = new double[] {2, 3, 4}; double[] data = new double[] {2, 3, 4};
@ -201,8 +213,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testAssign() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssign(Nd4jBackend backend) {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3}); DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
DataBuffer one = Nd4j.createBuffer(new double[] {1}); DataBuffer one = Nd4j.createBuffer(new double[] {1});
DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3}); DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3});
@ -212,8 +226,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testOffset() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOffset(Nd4jBackend backend) {
DataBuffer create = Nd4j.createBuffer(new double[] {1, 2, 3, 4}, 2); DataBuffer create = Nd4j.createBuffer(new double[] {1, 2, 3, 4}, 2);
assertEquals(2, create.length()); assertEquals(2, create.length());
assertEquals(0, create.offset()); assertEquals(0, create.offset());
@ -222,8 +238,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testReallocation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4}); DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity()); assertEquals(4, buffer.capacity());
double[] old = buffer.asDouble(); double[] old = buffer.asDouble();
@ -232,10 +250,12 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
assertArrayEquals(old, Arrays.copyOf(buffer.asDouble(), 4), 1e-1); assertArrayEquals(old, Arrays.copyOf(buffer.asDouble(), 4), 1e-1);
} }
@Test @Test
public void testReallocationWorkspace() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID"); MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4}); DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4});
@ -249,7 +269,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddressPointer(){ public void testAddressPointer(){
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){ if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
return; return;

View File

@ -27,7 +27,9 @@ import org.bytedeco.javacpp.indexer.Indexer;
import org.junit.jupiter.api.*; import org.junit.jupiter.api.*;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
@ -54,14 +56,9 @@ import static org.junit.jupiter.api.Assertions.*;
* @author Adam Gibson * @author Adam Gibson
*/ */
@Disabled("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") @Disabled("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public class FloatDataBufferTest extends BaseNd4jTest { public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
DataType initialType; DataType initialType = Nd4j.dataType();
public FloatDataBufferTest(Nd4jBackend backend) {
super(backend);
initialType = Nd4j.dataType();
}
@BeforeEach @BeforeEach
public void before() { public void before() {
@ -76,7 +73,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testPointerCreation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPointerCreation(Nd4jBackend backend) {
FloatPointer floatPointer = new FloatPointer(1, 2, 3, 4); FloatPointer floatPointer = new FloatPointer(1, 2, 3, 4);
Indexer indexer = FloatIndexer.create(floatPointer); Indexer indexer = FloatIndexer.create(floatPointer);
DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.FLOAT, 4, indexer); DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.FLOAT, 4, indexer);
@ -85,7 +84,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testGetSet() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetSet(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4}; float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
float[] d2 = d.asFloat(); float[] d2 = d.asFloat();
@ -96,7 +97,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testSerialization(@TempDir Path tempDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerialization(@TempDir Path tempDir,Nd4jBackend backend) throws Exception {
File dir = tempDir.toFile(); File dir = tempDir.toFile();
DataBuffer buf = Nd4j.createBuffer(5); DataBuffer buf = Nd4j.createBuffer(5);
String fileName = "buf.ser"; String fileName = "buf.ser";
@ -117,7 +120,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testDup() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDup(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4}; float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
DataBuffer d2 = d.dup(); DataBuffer d2 = d.dup();
@ -125,7 +130,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testToNio() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToNio(Nd4jBackend backend) {
DataBuffer buff = Nd4j.createTypedBuffer(new double[] {1, 2, 3, 4}, DataType.FLOAT); DataBuffer buff = Nd4j.createTypedBuffer(new double[] {1, 2, 3, 4}, DataType.FLOAT);
assertEquals(4, buff.length()); assertEquals(4, buff.length());
if (buff.allocationMode() == DataBuffer.AllocationMode.HEAP) if (buff.allocationMode() == DataBuffer.AllocationMode.HEAP)
@ -137,7 +144,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testPut() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPut(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4}; float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
d.put(0, 0.0); d.put(0, 0.0);
@ -148,7 +157,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testGetRange() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(0, 3); float[] get = buffer.getFloatsAt(0, 3);
float[] data = new float[] {1, 2, 3}; float[] data = new float[] {1, 2, 3};
@ -164,7 +175,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testGetOffsetRange() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetOffsetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(1, 3); float[] get = buffer.getFloatsAt(1, 3);
float[] data = new float[] {2, 3, 4}; float[] data = new float[] {2, 3, 4};
@ -181,7 +194,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test @Test
public void testAsBytes() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAsBytes(Nd4jBackend backend) {
INDArray arr = Nd4j.create(5); INDArray arr = Nd4j.create(5);
byte[] d = arr.data().asBytes(); byte[] d = arr.data().asBytes();
assertEquals(4 * 5, d.length,getFailureMessage()); assertEquals(4 * 5, d.length,getFailureMessage());
@ -191,7 +206,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testAssign() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssign(Nd4jBackend backend) {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3}); DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
DataBuffer one = Nd4j.createBuffer(new double[] {1}); DataBuffer one = Nd4j.createBuffer(new double[] {1});
DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3}); DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3});
@ -201,7 +218,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testReadWrite() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReadWrite(Nd4jBackend backend) throws Exception {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3}); DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
ByteArrayOutputStream bos = new ByteArrayOutputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(bos); DataOutputStream dos = new DataOutputStream(bos);
@ -215,7 +234,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testOffset() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOffset(Nd4jBackend backend) {
DataBuffer create = Nd4j.createBuffer(new float[] {1, 2, 3, 4}, 2); DataBuffer create = Nd4j.createBuffer(new float[] {1, 2, 3, 4}, 2);
assertEquals(2, create.length()); assertEquals(2, create.length());
assertEquals(0, create.offset()); assertEquals(0, create.offset());
@ -225,7 +246,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testReallocation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4}); DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity()); assertEquals(4, buffer.capacity());
float[] old = buffer.asFloat(); float[] old = buffer.asFloat();
@ -236,7 +259,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testReallocationWorkspace() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID"); MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
@ -253,7 +278,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
} }
@Test @Test
public void testAddressPointer(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddressPointer(Nd4jBackend backend){
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){ if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
return; return;
} }

View File

@ -23,7 +23,9 @@ package org.nd4j.linalg.api.buffer;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy; import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
@ -37,13 +39,12 @@ import java.util.Arrays;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
public class IntDataBufferTests extends BaseNd4jTest { public class IntDataBufferTests extends BaseNd4jTestWithBackends {
public IntDataBufferTests(Nd4jBackend backend) {
super(backend);
}
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBasicSerde1() throws Exception { public void testBasicSerde1() throws Exception {
@ -82,7 +83,9 @@ public class IntDataBufferTests extends BaseNd4jTest {
*/ */
@Test @Test
public void testReallocation() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4}); DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity()); assertEquals(4, buffer.capacity());
buffer.reallocate(6); buffer.reallocate(6);
@ -94,9 +97,11 @@ public class IntDataBufferTests extends BaseNd4jTest {
} }
@Test @Test
public void testReallocationWorkspace() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID"); MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4}); DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing; package org.nd4j.linalg.api.indexing;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -37,17 +38,15 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class) public class IndexingTests extends BaseNd4jTestWithBackends {
public class IndexingTests extends BaseNd4jTest {
public IndexingTests(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testINDArrayIndexingEqualToRank() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testINDArrayIndexingEqualToRank(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE); INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{ INDArray indexes = Nd4j.create(new double[][]{
{0,1,2}, {0,1,2},
@ -62,7 +61,9 @@ public class IndexingTests extends BaseNd4jTest {
@Test @Test
public void testINDArrayIndexingLessThanRankSimple() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testINDArrayIndexingLessThanRankSimple(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE); INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{ INDArray indexes = Nd4j.create(new double[][]{
{0}, {0},
@ -76,7 +77,9 @@ public class IndexingTests extends BaseNd4jTest {
@Test @Test
public void testINDArrayIndexingLessThanRankFourDimension() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testINDArrayIndexingLessThanRankFourDimension(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE); INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{ INDArray indexes = Nd4j.create(new double[][]{
{0},{1} {0},{1}
@ -89,7 +92,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testPutSimple() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutSimple(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2); INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2);
INDArray indexes = Nd4j.create(new double[][]{ INDArray indexes = Nd4j.create(new double[][]{
{0},{1} {0},{1}
@ -101,7 +106,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testGetScalar() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetScalar(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
INDArray d = arr.get(NDArrayIndex.point(1)); INDArray d = arr.get(NDArrayIndex.point(1));
assertTrue(d.isScalar()); assertTrue(d.isScalar());
@ -110,14 +117,18 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testNewAxis() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis(Nd4jBackend backend) {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3}); INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1)); INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
// System.out.println(view); // System.out.println(view);
} }
@Test @Test
public void testVectorIndexing() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorIndexing(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE); INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE);
int[] index = new int[] {5, 8, 9}; int[] index = new int[] {5, 8, 9};
INDArray columnsTest = x.getColumns(index); INDArray columnsTest = x.getColumns(index);
@ -129,7 +140,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testGetRowsColumnsMatrix() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRowsColumnsMatrix(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 6); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 6);
INDArray firstAndSecondColumnsAssertion = Nd4j.create(new double[][] {{1, 5}, {2, 6}, {3, 7}, {4, 8}}); INDArray firstAndSecondColumnsAssertion = Nd4j.create(new double[][] {{1, 5}, {2, 6}, {3, 7}, {4, 8}});
@ -147,7 +160,9 @@ public class IndexingTests extends BaseNd4jTest {
@Test @Test
public void testSlicing() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlicing(Nd4jBackend backend) {
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE); INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
INDArray slice1Assert = Nd4j.create(new double[] {2, 6, 10, 14}); INDArray slice1Assert = Nd4j.create(new double[] {2, 6, 10, 14});
INDArray slice1Test = arange.slice(1); INDArray slice1Test = arange.slice(1);
@ -155,7 +170,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testArangeMul() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testArangeMul(Nd4jBackend backend) {
INDArray arange = Nd4j.arange(1, 17).reshape('f', 4, 4).castTo(DataType.DOUBLE); INDArray arange = Nd4j.arange(1, 17).reshape('f', 4, 4).castTo(DataType.DOUBLE);
INDArrayIndex index = NDArrayIndex.interval(0, 2); INDArrayIndex index = NDArrayIndex.interval(0, 2);
INDArray get = arange.get(index, index); INDArray get = arange.get(index, index);
@ -167,7 +184,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testGetIndicesVector() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndicesVector(Nd4jBackend backend) {
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1); INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
INDArray test = Nd4j.create(new double[] {2, 3}); INDArray test = Nd4j.create(new double[] {2, 3});
INDArray result = line.get(NDArrayIndex.point(0), NDArrayIndex.interval(1, 3)); INDArray result = line.get(NDArrayIndex.point(0), NDArrayIndex.interval(1, 3));
@ -175,7 +194,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void testGetIndicesVectorView() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndicesVectorView(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape('c',5, 5); INDArray matrix = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape('c',5, 5);
INDArray column = matrix.getColumn(0).reshape(1,5); INDArray column = matrix.getColumn(0).reshape(1,5);
INDArray test = Nd4j.create(new double[] {6, 11}); INDArray test = Nd4j.create(new double[] {6, 11});
@ -193,7 +214,9 @@ public class IndexingTests extends BaseNd4jTest {
} }
@Test @Test
public void test2dGetPoint(){ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void test2dGetPoint(Nd4jBackend backend){
INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4); INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4);
for( int i=0; i<3; i++ ){ for( int i=0; i<3; i++ ){
INDArray exp = Nd4j.create(new double[]{i*4+1, i*4+2, i*4+3, i*4+4}); INDArray exp = Nd4j.create(new double[]{i*4+1, i*4+2, i*4+3, i*4+4});
@ -206,7 +229,7 @@ public class IndexingTests extends BaseNd4jTest {
assertEquals(exp, get); assertEquals(exp, get);
} }
for( int i=0; i<4; i++ ){ for( int i = 0; i < 4; i++) {
INDArray exp = Nd4j.create(new double[]{1+i, 5+i, 9+i}); INDArray exp = Nd4j.create(new double[]{1+i, 5+i, 9+i});
INDArray col = arr.getColumn(i); INDArray col = arr.getColumn(i);
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(i)); INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(i));

View File

@ -21,10 +21,11 @@
package org.nd4j.linalg.api.indexing; package org.nd4j.linalg.api.indexing;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -49,16 +50,15 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class IndexingTestsC extends BaseNd4jTest { public class IndexingTestsC extends BaseNd4jTestWithBackends {
public IndexingTestsC(Nd4jBackend backend) {
super(backend);
}
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNegativeBounds() { public void testNegativeBounds() {
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5); INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1)); INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
@ -70,7 +70,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion,get); assertEquals(assertion,get);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis() { public void testNewAxis() {
INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2); INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2);
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all()); INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all());
@ -79,7 +81,9 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void broadcastBug() { public void broadcastBug() {
INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2}); INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2});
final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0)); final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0));
@ -90,7 +94,9 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIntervalsIn3D() { public void testIntervalsIn3D() {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE); INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2); INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
@ -99,7 +105,9 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSmallInterval() { public void testSmallInterval() {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE); INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2); INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
@ -108,7 +116,9 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllWithNewAxisAndInterval() { public void testAllWithNewAxisAndInterval() {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3); INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3);
@ -117,7 +127,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion2, get2); assertEquals(assertion2, get2);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllWithNewAxisInMiddle() { public void testAllWithNewAxisInMiddle() {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3); INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3);
@ -126,7 +138,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion2, get2); assertEquals(assertion2, get2);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllWithNewAxis() { public void testAllWithNewAxis() {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray get = arr.get(newAxis(), all(), point(1)); INDArray get = arr.get(newAxis(), all(), point(1));
@ -136,7 +150,9 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexingWithMmul() { public void testIndexingWithMmul() {
INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3); INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1); INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
@ -147,7 +163,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion, c); assertEquals(assertion, c);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPointPointInterval() { public void testPointPointInterval() {
INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3); INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3);
INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3)); INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3));
@ -156,7 +174,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion, get); assertEquals(assertion, get);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIntervalLowerBound() { public void testIntervalLowerBound() {
INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2)); INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
@ -167,7 +187,9 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetPointRowVector() { public void testGetPointRowVector() {
INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1); INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1);
@ -177,7 +199,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(Nd4j.linspace(1, 100, 100, DataType.DOUBLE), arr2); assertEquals(Nd4j.linspace(1, 100, 100, DataType.DOUBLE), arr2);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpecifiedIndexVector() { public void testSpecifiedIndexVector() {
INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4); INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2);
@ -194,7 +218,9 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRowIndexing() { public void testPutRowIndexing() {
INDArray arr = Nd4j.ones(1, 10); INDArray arr = Nd4j.ones(1, 10);
INDArray row = Nd4j.create(1, 10); INDArray row = Nd4j.create(1, 10);
@ -204,7 +230,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(arr, row); assertEquals(arr, row);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorIndexing2() { public void testVectorIndexing2() {
INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true)); INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true));
INDArray assertion = Nd4j.create(new double[] {2, 4}); INDArray assertion = Nd4j.create(new double[] {2, 4});
@ -219,7 +247,9 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOffsetsC() { public void testOffsetsC() {
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
assertEquals(3, NDArrayIndex.offset(arr, 1, 1)); assertEquals(3, NDArrayIndex.offset(arr, 1, 1));
@ -235,7 +265,9 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexFor() { public void testIndexFor() {
long[] shape = {1, 2}; long[] shape = {1, 2};
INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape); INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape);
@ -244,7 +276,9 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetScalar() { public void testGetScalar() {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
INDArray d = arr.get(point(1)); INDArray d = arr.get(point(1));
@ -253,7 +287,9 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorIndexing() { public void testVectorIndexing() {
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1); INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1);
INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5}); INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
@ -261,14 +297,18 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion, viewTest); assertEquals(assertion, viewTest);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNegativeIndices() { public void testNegativeIndices() {
INDArray test = Nd4j.create(10, 10, 10); INDArray test = Nd4j.create(10, 10, 10);
test.putScalar(new int[] {0, 0, -1}, 1.0); test.putScalar(new int[] {0, 0, -1}, 1.0);
assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber()); assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber());
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndices2d() { public void testGetIndices2d() {
INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2); INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2);
INDArray firstRow = twoByTwo.getRow(0); INDArray firstRow = twoByTwo.getRow(0);
@ -286,7 +326,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(Nd4j.create(new double[] {4}, new int[]{1,1}), individualElement); assertEquals(Nd4j.create(new double[] {4}, new int[]{1,1}), individualElement);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRow() { public void testGetRow() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5); INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5);
@ -303,7 +345,9 @@ public class IndexingTestsC extends BaseNd4jTest {
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRowEdgeCase() { public void testGetRowEdgeCase() {
INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1); INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
INDArray get = rowVec.getRow(0); //Returning shape [1,1] INDArray get = rowVec.getRow(0); //Returning shape [1,1]
@ -312,7 +356,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(rowVec, get); assertEquals(rowVec, get);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumnEdgeCase() { public void testGetColumnEdgeCase() {
INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose(); INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose();
INDArray get = colVec.getColumn(0); //Returning shape [1,1] INDArray get = colVec.getColumn(0); //Returning shape [1,1]
@ -321,7 +367,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(colVec, get); assertEquals(colVec, get);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcatColumns() { public void testConcatColumns() {
INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE); INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE);
INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE); INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE);
@ -330,7 +378,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion, concat); assertEquals(assertion, concat);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndicesVector() { public void testGetIndicesVector() {
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1); INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
INDArray test = Nd4j.create(new double[] {2, 3}); INDArray test = Nd4j.create(new double[] {2, 3});
@ -338,7 +388,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(test, result); assertEquals(test, result);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testArangeMul() { public void testArangeMul() {
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE); INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
INDArrayIndex index = interval(0, 2); INDArrayIndex index = interval(0, 2);
@ -349,7 +401,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion, mul); assertEquals(assertion, mul);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexingThorough(){ public void testIndexingThorough(){
long[] fullShape = {3,4,5,6,7}; long[] fullShape = {3,4,5,6,7};
@ -549,7 +603,9 @@ public class IndexingTestsC extends BaseNd4jTest {
return d; return d;
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void debugging(){ public void debugging(){
long[] inShape = {3,4}; long[] inShape = {3,4};
INDArrayIndex[] indexes = new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(1, 2, 4)}; INDArrayIndex[] indexes = new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(1, 2, 4)};

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing.resolve; package org.nd4j.linalg.api.indexing.resolve;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -36,15 +37,14 @@ import static org.junit.jupiter.api.Assertions.*;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class NDArrayIndexResolveTests extends BaseNd4jTest {
public NDArrayIndexResolveTests(Nd4jBackend backend) { public class NDArrayIndexResolveTests extends BaseNd4jTestWithBackends {
super(backend);
}
@Test @Test
public void testResolvePoint() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testResolvePoint(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2);
INDArrayIndex[] test = NDArrayIndex.resolve(arr.shape(), NDArrayIndex.point(1)); INDArrayIndex[] test = NDArrayIndex.resolve(arr.shape(), NDArrayIndex.point(1));
INDArrayIndex[] assertion = {NDArrayIndex.point(1), NDArrayIndex.all()}; INDArrayIndex[] assertion = {NDArrayIndex.point(1), NDArrayIndex.all()};
@ -59,6 +59,8 @@ public class NDArrayIndexResolveTests extends BaseNd4jTest {
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testResolvePointVector() { public void testResolvePointVector() {
INDArray arr = Nd4j.linspace(1, 4, 4); INDArray arr = Nd4j.linspace(1, 4, 4);
INDArrayIndex[] getPoint = {NDArrayIndex.point(1)}; INDArrayIndex[] getPoint = {NDArrayIndex.point(1)};

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing.shape; package org.nd4j.linalg.api.indexing.shape;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.Indices; import org.nd4j.linalg.indexing.Indices;
@ -34,19 +35,15 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class IndexShapeTests extends BaseNd4jTest {
public IndexShapeTests(Nd4jBackend backend) {
super(backend);
}
public class IndexShapeTests extends BaseNd4jTestWithBackends {
private int[] shape = {1, 1, 2, 1, 3, 4, 5, 1}; private int[] shape = {1, 1, 2, 1, 3, 4, 5, 1};
@Test @Test
public void testSinglePoint() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSinglePoint(Nd4jBackend backend) {
/* /*
Assumes all indexes are filled out. Assumes all indexes are filled out.
Test simple general point case Test simple general point case
@ -77,7 +74,9 @@ public class IndexShapeTests extends BaseNd4jTest {
} }
@Test @Test
public void testInterval() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInterval(Nd4jBackend backend) {
int[] basicAssertion = {1, 1, 1, 1, 3, 1, 2, 1}; int[] basicAssertion = {1, 1, 1, 1, 3, 1, 2, 1};
INDArrayIndex[] basicTest = {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 1), INDArrayIndex[] basicTest = {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 1),
NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(1, 2), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(1, 2),
@ -88,7 +87,9 @@ public class IndexShapeTests extends BaseNd4jTest {
@Test @Test
public void testNewAxis() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis(Nd4jBackend backend) {
//normal prepend //normal prepend
int[] prependAssertion = {1, 1, 1, 1, 2, 1, 3, 4, 5, 1}; int[] prependAssertion = {1, 1, 1, 1, 2, 1, 3, 4, 5, 1};
INDArrayIndex[] prependTest = {NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.all(), INDArrayIndex[] prependTest = {NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.all(),

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing.shape; package org.nd4j.linalg.api.indexing.shape;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.Indices; import org.nd4j.linalg.indexing.Indices;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
@ -33,25 +34,26 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class IndexShapeTests2d extends BaseNd4jTest {
public IndexShapeTests2d(Nd4jBackend backend) { public class IndexShapeTests2d extends BaseNd4jTestWithBackends {
super(backend);
}
private long[] shape = {3, 2}; private long[] shape = {3, 2};
@Test @Test
public void test2dCases() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void test2dCases(Nd4jBackend backend) {
assertArrayEquals(new long[] {1, 2}, Indices.shape(shape, NDArrayIndex.point(1))); assertArrayEquals(new long[] {1, 2}, Indices.shape(shape, NDArrayIndex.point(1)));
assertArrayEquals(new long[] {3, 1}, assertArrayEquals(new long[] {3, 1},
Indices.shape(shape, NDArrayIndex.all(), NDArrayIndex.point(1))); Indices.shape(shape, NDArrayIndex.all(), NDArrayIndex.point(1)));
} }
@Test @Test
public void testNewAxis2d() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis2d(Nd4jBackend backend) {
assertArrayEquals(new long[] {1, 3, 2}, Indices.shape(shape, assertArrayEquals(new long[] {1, 3, 2}, Indices.shape(shape,
NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all())); NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all()));
assertArrayEquals(new long[] {3, 1, 2}, Indices.shape(shape, assertArrayEquals(new long[] {3, 1, 2}, Indices.shape(shape,

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.api.iterator;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -33,15 +34,14 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class NDIndexIteratorTest extends BaseNd4jTest {
public NDIndexIteratorTest(Nd4jBackend backend) { public class NDIndexIteratorTest extends BaseNd4jTestWithBackends {
super(backend);
}
@Test @Test
public void testIterate() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIterate(Nd4jBackend backend) {
val shapeIter = new NdIndexIterator(2, 2); val shapeIter = new NdIndexIterator(2, 2);
val possibleSolutions = new long[][] {{0, 0}, {0, 1}, {1, 0}, {1, 1},}; val possibleSolutions = new long[][] {{0, 0}, {0, 1}, {1, 0}, {1, 1},};

View File

@ -28,9 +28,10 @@ import org.apache.commons.lang3.ArrayUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -45,18 +46,15 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class TestNdArrReadWriteTxt extends BaseNd4jTest {
public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
public TestNdArrReadWriteTxt(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void compareAfterWrite(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
int [] ranksToCheck = new int[] {0,1,2,3,4}; int [] ranksToCheck = new int[] {0,1,2,3,4};
for (int i=0; i<ranksToCheck.length;i++) { for (int i = 0; i < ranksToCheck.length; i++) {
// log.info("Checking read write arrays with rank " + ranksToCheck[i]); // log.info("Checking read write arrays with rank " + ranksToCheck[i]);
compareArrays(ranksToCheck[i],ordering(), testDir); compareArrays(ranksToCheck[i],ordering(), testDir);
} }
@ -84,7 +82,9 @@ public class TestNdArrReadWriteTxt extends BaseNd4jTest {
} }
@Test @Test
public void testNd4jReadWriteText(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNd4jReadWriteText(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
int count = 0; int count = 0;

View File

@ -25,9 +25,10 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import java.nio.file.Path; import java.nio.file.Path;
@ -35,17 +36,14 @@ import java.nio.file.Path;
import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays; import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class TestNdArrReadWriteTxtC extends BaseNd4jTest {
public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends {
public TestNdArrReadWriteTxtC(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void compareAfterWrite(@TempDir Path testDir) throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
int[] ranksToCheck = new int[]{0, 1, 2, 3, 4}; int[] ranksToCheck = new int[]{0, 1, 2, 3, 4};
for (int i = 0; i < ranksToCheck.length; i++) { for (int i = 0; i < ranksToCheck.length; i++) {
log.info("Checking read write arrays with rank " + ranksToCheck[i]); log.info("Checking read write arrays with rank " + ranksToCheck[i]);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.ndarray; package org.nd4j.linalg.api.ndarray;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
@ -32,16 +33,14 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class TestSerialization extends BaseNd4jTest {
public TestSerialization(Nd4jBackend backend) { public class TestSerialization extends BaseNd4jTestWithBackends {
super(backend);
}
@Test @Test
public void testSerializationFullArrayNd4jWriteRead() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10); INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10); INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
@ -71,7 +70,9 @@ public class TestSerialization extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationFullArrayJava() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayJava(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10); INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10); INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
@ -102,7 +103,9 @@ public class TestSerialization extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationOnViewsNd4jWriteRead() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10); INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10); INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
@ -138,7 +141,9 @@ public class TestSerialization extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationOnViewsJava() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsJava(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10); INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10); INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);

View File

@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -39,23 +40,21 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class TestSerializationDoubleToFloat extends BaseNd4jTest {
DataType initialType; public class TestSerializationDoubleToFloat extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
public TestSerializationDoubleToFloat(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@AfterEach @AfterEach
public void after() { public void after() {
DataTypeUtil.setDTypeForContext(this.initialType); DataTypeUtil.setDTypeForContext(this.initialType);
} }
@Test @Test
public void testSerializationFullArrayNd4jWriteRead() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 4; int length = 4;
//WRITE OUT A DOUBLE ARRAY //WRITE OUT A DOUBLE ARRAY
@ -93,7 +92,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationFullArrayJava() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayJava(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
Nd4j.create(1); Nd4j.create(1);
DataTypeUtil.setDTypeForContext(DataType.DOUBLE); DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
@ -123,7 +124,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationOnViewsNd4jWriteRead() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
Nd4j.create(1); Nd4j.create(1);
DataTypeUtil.setDTypeForContext(DataType.DOUBLE); DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
@ -153,7 +156,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
} }
@Test @Test
public void testSerializationOnViewsJava() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsJava(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
Nd4j.create(1); Nd4j.create(1);
DataTypeUtil.setDTypeForContext(DataType.DOUBLE); DataTypeUtil.setDTypeForContext(DataType.DOUBLE);

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.api.ndarray;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -37,23 +38,21 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class TestSerializationFloatToDouble extends BaseNd4jTest {
DataType initialType; public class TestSerializationFloatToDouble extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
public TestSerializationFloatToDouble(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@AfterEach @AfterEach
public void after() { public void after() {
Nd4j.setDataType(this.initialType); Nd4j.setDataType(this.initialType);
} }
@Test @Test
public void testSerializationFullArrayNd4jWriteRead() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100; int length = 100;
//WRITE OUT A FLOAT ARRAY //WRITE OUT A FLOAT ARRAY
@ -85,7 +84,9 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
assertTrue(Transforms.abs(arr1.sub(arr2).div(arr1)).maxNumber().doubleValue() < 0.01); assertTrue(Transforms.abs(arr1.sub(arr2).div(arr1)).maxNumber().doubleValue() < 0.01);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayJava() throws Exception { public void testSerializationFullArrayJava() throws Exception {
int length = 100; int length = 100;
Nd4j.create(1); Nd4j.create(1);
@ -116,7 +117,9 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
assertTrue(Transforms.abs(arr1.sub(arr2).div(arr1)).maxNumber().doubleValue() < 0.01); assertTrue(Transforms.abs(arr1.sub(arr2).div(arr1)).maxNumber().doubleValue() < 0.01);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsNd4jWriteRead() throws Exception { public void testSerializationOnViewsNd4jWriteRead() throws Exception {
int length = 100; int length = 100;
Nd4j.create(1); Nd4j.create(1);
@ -146,7 +149,9 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
assertTrue(Transforms.abs(sub1.sub(arr2).div(sub1)).maxNumber().doubleValue() < 0.01); assertTrue(Transforms.abs(sub1.sub(arr2).div(sub1)).maxNumber().doubleValue() < 0.01);
} }
@Test @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsJava() throws Exception { public void testSerializationOnViewsJava() throws Exception {
int length = 100; int length = 100;
Nd4j.create(1); Nd4j.create(1);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.rng; package org.nd4j.linalg.api.rng;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -33,14 +34,13 @@ import static org.junit.jupiter.api.Assertions.*;
/** /**
* @author Adam Gibson * @author Adam Gibson
*/ */
@RunWith(Parameterized.class)
public class RngTests extends BaseNd4jTest { public class RngTests extends BaseNd4jTestWithBackends {
public RngTests(Nd4jBackend backend) {
super(backend);
}
@Test @Test
public void testRngConstitency() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRngConstitency(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(123); Nd4j.getRandom().setSeed(123);
INDArray arr = Nd4j.rand(1, 5); INDArray arr = Nd4j.rand(1, 5);
Nd4j.getRandom().setSeed(123); Nd4j.getRandom().setSeed(123);
@ -49,7 +49,9 @@ public class RngTests extends BaseNd4jTest {
} }
@Test @Test
public void testRandomWithOrder() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomWithOrder(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -105,7 +107,9 @@ public class RngTests extends BaseNd4jTest {
} }
@Test @Test
public void testRandomBinomial() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomBinomial(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
//silly tests. Just increasing the usage for randomBinomial to stop compiler warnings. //silly tests. Just increasing the usage for randomBinomial to stop compiler warnings.
INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3); INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3);

View File

@ -23,9 +23,10 @@ package org.nd4j.linalg.api.string;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.Assert; import org.junit.Assert;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -35,22 +36,23 @@ import org.nd4j.linalg.string.NDArrayStrings;
* @author Adam Gibson * @author Adam Gibson
*/ */
@Slf4j @Slf4j
@RunWith(Parameterized.class)
public class TestFormatting extends BaseNd4jTest {
public TestFormatting(Nd4jBackend backend) { public class TestFormatting extends BaseNd4jTestWithBackends {
super(backend);
}
@Test @Test
public void testTwoByTwo() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTwoByTwo(Nd4jBackend backend) {
INDArray arr = Nd4j.create(2, 2, 2, 2); INDArray arr = Nd4j.create(2, 2, 2, 2);
System.out.println(new NDArrayStrings().format(arr)); System.out.println(new NDArrayStrings().format(arr));
} }
@Test @Test
public void testNd4jArrayString() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNd4jArrayString(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new float[]{1f, 20000000f, 40.838383f, 3f}, new int[]{2, 2}); INDArray arr = Nd4j.create(new float[]{1f, 20000000f, 40.838383f, 3f}, new int[]{2, 2});
@ -71,7 +73,9 @@ public class TestFormatting extends BaseNd4jTest {
} }
@Test @Test
public void testRange() { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRange(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[][]{ INDArray arr = Nd4j.create(new double[][]{
{-1,0,1,0}, {-1,0,1,0},
{-0.1, 0.1, -10, 10}, {-0.1, 0.1, -10, 10},

Some files were not shown because too many files have changed in this diff Show More