Migrate parameterized tests to JUnit 5

master
agibsonccc 2021-03-16 22:08:35 +09:00
parent 82bdcc21d2
commit 3c6014271e
260 changed files with 10787 additions and 6270 deletions
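
Every file in this change follows the same migration pattern: the JUnit 4 @RunWith(Parameterized.class) runner, constructor injection, and @Parameterized.Parameters factory are replaced by JUnit 5's @ParameterizedTest, a static @MethodSource factory returning Stream<Arguments>, and a test method parameter injected on each invocation. A minimal sketch of the target shape, assuming CNN2DFormat from org.deeplearning4j.nn.conf as the parameter type (the class and test names below are illustrative, not taken from the commit):

import java.util.Arrays;
import java.util.stream.Stream;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

class FormatParameterizedExample {

    // Replaces @Parameterized.Parameters: one Arguments instance per enum value.
    // Arrays.stream(CNN2DFormat.values()).map(Arguments::of) would also work.
    public static Stream<Arguments> params() {
        return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of);
    }

    // Replaces @RunWith(Parameterized.class) plus constructor injection:
    // @MethodSource names the bare factory method (or a fully qualified
    // ClassName#method); the value arrives as a method argument per invocation.
    @ParameterizedTest
    @MethodSource("params")
    void runsOncePerFormat(CNN2DFormat format) {
        // test body uses the injected format, e.g. skip when format != CNN2DFormat.NCHW
    }
}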

CNNGradientCheckTest.java

@ -37,8 +37,10 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -47,13 +49,14 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.deeplearning4j.nn.conf.ConvolutionMode.Same;
import static org.deeplearning4j.nn.conf.ConvolutionMode.Truncate;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(Parameterized.class)
@DisplayName("Cnn Gradient Check Test")
class CNNGradientCheckTest extends BaseDL4JTest {
@ -71,15 +74,10 @@ class CNNGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
private CNN2DFormat format;
public CNNGradientCheckTest(CNN2DFormat format) {
this.format = format;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params() {
return CNN2DFormat.values();
public static Stream<Arguments> params() {
return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of);
}
@Override
@ -89,9 +87,11 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Gradient CNNMLN")
void testGradientCNNMLN() {
@ParameterizedTest
@MethodSource("params")
public void testGradientCNNMLN(CNN2DFormat format) {
if (// Only test NCHW due to flat input format...
this.format != CNN2DFormat.NCHW)
format != CNN2DFormat.NCHW)
return;
// Parameterized test, testing combinations of:
// (a) activation function
@ -146,9 +146,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Gradient CNNL 1 L 2 MLN")
void testGradientCNNL1L2MLN() {
void testGradientCNNL1L2MLN(CNN2DFormat format) {
if (// Only test NCHW due to flat input format...
this.format != CNN2DFormat.NCHW)
format != CNN2DFormat.NCHW)
return;
// Parameterized test, testing combinations of:
// (a) activation function
@ -245,7 +245,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn With Space To Batch")
void testCnnWithSpaceToBatch() {
@ParameterizedTest
@MethodSource("params")
public void testCnnWithSpaceToBatch(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345);
int nOut = 4;
int[] minibatchSizes = { 2, 4 };
@ -289,7 +291,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn With Upsampling")
void testCnnWithUpsampling() {
@ParameterizedTest
@MethodSource("params")
void testCnnWithUpsampling(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345);
int nOut = 4;
int[] minibatchSizes = { 1, 3 };
@ -323,7 +327,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn With Subsampling")
void testCnnWithSubsampling() {
@ParameterizedTest
@MethodSource("params")
void testCnnWithSubsampling(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345);
int nOut = 4;
int[] minibatchSizes = { 1, 3 };
@ -365,7 +371,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn With Subsampling V 2")
void testCnnWithSubsamplingV2() {
@ParameterizedTest
@MethodSource("params")
void testCnnWithSubsamplingV2(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345);
int nOut = 4;
int[] minibatchSizes = { 1, 3 };
@ -403,7 +411,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn Locally Connected 2 D")
void testCnnLocallyConnected2D() {
@ParameterizedTest
@MethodSource("params")
void testCnnLocallyConnected2D(CNN2DFormat format) {
int nOut = 3;
int width = 5;
int height = 5;
@ -433,7 +443,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn Multi Layer")
void testCnnMultiLayer() {
@ParameterizedTest
@MethodSource("params")
void testCnnMultiLayer(CNN2DFormat format) {
int nOut = 2;
int[] minibatchSizes = { 1, 2, 5 };
int width = 5;
@ -473,7 +485,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn Same Padding Mode")
void testCnnSamePaddingMode() {
@ParameterizedTest
@MethodSource("params")
void testCnnSamePaddingMode(CNN2DFormat format) {
int nOut = 2;
int[] minibatchSizes = { 1, 3, 3, 2, 1, 2 };
// Same padding mode: insensitive to exact input size...
@ -507,7 +521,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn Same Padding Mode Strided")
void testCnnSamePaddingModeStrided() {
@ParameterizedTest
@MethodSource("params")
void testCnnSamePaddingModeStrided(CNN2DFormat format) {
int nOut = 2;
int[] minibatchSizes = { 1, 3 };
int width = 16;
@ -550,7 +566,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn Zero Padding Layer")
void testCnnZeroPaddingLayer() {
@ParameterizedTest
@MethodSource("params")
void testCnnZeroPaddingLayer(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345);
int nOut = 4;
int width = 6;
@ -596,7 +614,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Deconvolution 2 D")
void testDeconvolution2D() {
@ParameterizedTest
@MethodSource("params")
void testDeconvolution2D(CNN2DFormat format) {
int nOut = 2;
int[] minibatchSizes = new int[] { 1, 3, 3, 1, 3 };
int[] kernelSizes = new int[] { 1, 1, 1, 3, 3 };
@ -641,7 +661,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Separable Conv 2 D")
void testSeparableConv2D() {
@ParameterizedTest
@MethodSource("params")
void testSeparableConv2D(CNN2DFormat format) {
int nOut = 2;
int width = 6;
int height = 6;
@ -686,7 +708,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cnn Dilated")
void testCnnDilated() {
@ParameterizedTest
@MethodSource("params")
void testCnnDilated(CNN2DFormat format) {
int nOut = 2;
int minibatchSize = 2;
int width = 8;
@ -736,7 +760,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Cropping 2 D Layer")
void testCropping2DLayer() {
@ParameterizedTest
@MethodSource("params")
void testCropping2DLayer(CNN2DFormat format) {
Nd4j.getRandom().setSeed(12345);
int nOut = 2;
int width = 12;
@ -780,7 +806,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
@Test
@DisplayName("Test Depthwise Conv 2 D")
void testDepthwiseConv2D() {
@ParameterizedTest
@MethodSource("params")
void testDepthwiseConv2D(CNN2DFormat format) {
int nIn = 3;
int depthMultiplier = 2;
int nOut = nIn * depthMultiplier;

YoloGradientCheckTests.java

@ -39,8 +39,10 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -55,26 +57,22 @@ import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class YoloGradientCheckTests extends BaseDL4JTest {
static {
Nd4j.setDataType(DataType.DOUBLE);
}
private CNN2DFormat format;
public YoloGradientCheckTests(CNN2DFormat format){
this.format = format;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return CNN2DFormat.values();
}
public static Stream<Arguments> params() {
return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of);
}
@Override
public long getTimeoutMilliseconds() {
@ -82,7 +80,9 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
}
@Test
public void testYoloOutputLayer() {
@ParameterizedTest
@MethodSource("params")
public void testYoloOutputLayer(CNN2DFormat format) {
int depthIn = 2;
int c = 3;
int b = 3;
@ -159,13 +159,13 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
}
}
private static INDArray yoloLabels(int mb, int c, int h, int w){
private static INDArray yoloLabels(int mb, int c, int h, int w) {
int labelDepth = 4 + c;
INDArray labels = Nd4j.zeros(mb, labelDepth, h, w);
//put 1 object per minibatch, at positions (0,0), (1,1) etc.
//Positions for label boxes: (1,1) to (2,2), (2,2) to (4,4) etc
for( int i=0; i<mb; i++ ){
for( int i = 0; i < mb; i++) {
//Class labels
labels.putScalar(i, 4 + i%c, i%h, i%w, 1);
@ -181,7 +181,9 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
@Test
public void yoloGradientCheckRealData(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("params")
public void yoloGradientCheckRealData(CNN2DFormat format, @TempDir Path testDir) throws Exception {
Nd4j.getRandom().setSeed(12345);
InputStream is1 = new ClassPathResource("yolo/VOC_TwoImage/JPEGImages/2007_009346.jpg").getInputStream();
InputStream is2 = new ClassPathResource("yolo/VOC_TwoImage/Annotations/2007_009346.xml").getInputStream();

ConvDataFormatTests.java

@ -39,8 +39,10 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -48,24 +50,19 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class ConvDataFormatTests extends BaseDL4JTest {
private final DataType dataType;
public ConvDataFormatTests(DataType dataType){
this.dataType = dataType;
}
@Parameterized.Parameters(name = "{0}")
public static Object[] params(){
return new DataType[]{DataType.FLOAT, DataType.DOUBLE};
public static Stream<Arguments> params(){
return Arrays.asList(new DataType[]{DataType.FLOAT, DataType.DOUBLE}).stream().map(Arguments::of);
}
@Override
@ -74,7 +71,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testConv2d() {
@MethodSource("params")
@ParameterizedTest
public void testConv2d(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -83,15 +82,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getConv2dNet(CNN2DFormat.NHWC, false, cm))
.net1(getConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -107,7 +106,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testSubsampling2d() {
@MethodSource("params")
@ParameterizedTest
public void testSubsampling2d(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -116,15 +117,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSubsampling2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getSubsampling2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getSubsampling2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getSubsampling2dNet(CNN2DFormat.NHWC, false, cm))
.net1(getSubsampling2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getSubsampling2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getSubsampling2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getSubsampling2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -140,7 +141,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testDepthwiseConv2d() {
@MethodSource("params")
@ParameterizedTest
public void testDepthwiseConv2d(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -149,15 +152,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getDepthwiseConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getDepthwiseConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getDepthwiseConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getDepthwiseConv2dNet(CNN2DFormat.NHWC, false, cm))
.net1(getDepthwiseConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getDepthwiseConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getDepthwiseConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getDepthwiseConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -173,7 +176,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testSeparableConv2d() {
@MethodSource("params")
@ParameterizedTest
public void testSeparableConv2d(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -182,15 +187,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSeparableConv2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getSeparableConv2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getSeparableConv2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getSeparableConv2dNet(CNN2DFormat.NHWC, false, cm))
.net1(getSeparableConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getSeparableConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getSeparableConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getSeparableConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -206,7 +211,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testDeconv2d() {
@MethodSource("params")
@ParameterizedTest
public void testDeconv2d(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -215,15 +222,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getDeconv2DNet2dNet(CNN2DFormat.NCHW, true, cm))
.net2(getDeconv2DNet2dNet(CNN2DFormat.NCHW, false, cm))
.net3(getDeconv2DNet2dNet(CNN2DFormat.NHWC, true, cm))
.net4(getDeconv2DNet2dNet(CNN2DFormat.NHWC, false, cm))
.net1(getDeconv2DNet2dNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getDeconv2DNet2dNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getDeconv2DNet2dNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getDeconv2DNet2dNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -239,7 +246,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testLRN() {
@MethodSource("params")
@ParameterizedTest
public void testLRN(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -248,15 +257,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLrnLayer(CNN2DFormat.NCHW, true, cm))
.net2(getLrnLayer(CNN2DFormat.NCHW, false, cm))
.net3(getLrnLayer(CNN2DFormat.NHWC, true, cm))
.net4(getLrnLayer(CNN2DFormat.NHWC, false, cm))
.net1(getLrnLayer(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getLrnLayer(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getLrnLayer(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getLrnLayer(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -272,7 +281,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testZeroPaddingLayer(){
@MethodSource("params")
@ParameterizedTest
public void testZeroPaddingLayer(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
@ -280,15 +291,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getZeroPaddingNet(CNN2DFormat.NCHW, true))
.net2(getZeroPaddingNet(CNN2DFormat.NCHW, false))
.net3(getZeroPaddingNet(CNN2DFormat.NHWC, true))
.net4(getZeroPaddingNet(CNN2DFormat.NHWC, false))
.net1(getZeroPaddingNet(dataType,CNN2DFormat.NCHW, true))
.net2(getZeroPaddingNet(dataType,CNN2DFormat.NCHW, false))
.net3(getZeroPaddingNet(dataType,CNN2DFormat.NHWC, true))
.net4(getZeroPaddingNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -303,7 +314,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testCropping2DLayer(){
@MethodSource("params")
@ParameterizedTest
public void testCropping2DLayer(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
@ -311,15 +324,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getCropping2dNet(CNN2DFormat.NCHW, true))
.net2(getCropping2dNet(CNN2DFormat.NCHW, false))
.net3(getCropping2dNet(CNN2DFormat.NHWC, true))
.net4(getCropping2dNet(CNN2DFormat.NHWC, false))
.net1(getCropping2dNet(dataType,CNN2DFormat.NCHW, true))
.net2(getCropping2dNet(dataType,CNN2DFormat.NCHW, false))
.net3(getCropping2dNet(dataType,CNN2DFormat.NHWC, true))
.net4(getCropping2dNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -334,7 +347,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testUpsampling2d(){
@MethodSource("params")
@ParameterizedTest
public void testUpsampling2d(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
@ -342,15 +357,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getUpsamplingNet(CNN2DFormat.NCHW, true))
.net2(getUpsamplingNet(CNN2DFormat.NCHW, false))
.net3(getUpsamplingNet(CNN2DFormat.NHWC, true))
.net4(getUpsamplingNet(CNN2DFormat.NHWC, false))
.net1(getUpsamplingNet(dataType,CNN2DFormat.NCHW, true))
.net2(getUpsamplingNet(dataType,CNN2DFormat.NCHW, false))
.net3(getUpsamplingNet(dataType,CNN2DFormat.NHWC, true))
.net4(getUpsamplingNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -365,7 +380,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testBatchNormNet(){
@MethodSource("params")
@ParameterizedTest
public void testBatchNormNet(DataType dataType) {
try {
for(boolean useLogStd : new boolean[]{true, false}) {
for (boolean helpers : new boolean[]{false, true}) {
@ -374,15 +391,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std");
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, true))
.net2(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, false))
.net3(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, true))
.net4(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, false))
.net1(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NCHW, true))
.net2(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NCHW, false))
.net3(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NHWC, true))
.net4(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -398,7 +415,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testCnnLossLayer() {
@MethodSource("params")
@ParameterizedTest
public void testCnnLossLayer(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
@ -406,8 +425,8 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labelsNHWC = TestUtils.randomOneHot(this.dataType,2*6*6, 3);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labelsNHWC = TestUtils.randomOneHot(dataType,2*6*6, 3);
labelsNHWC = labelsNHWC.reshape(2,6,6,3);
INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup();
@ -434,7 +453,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testSpaceToDepthNet(){
@MethodSource("params")
@ParameterizedTest
public void testSpaceToDepthNet(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
@ -442,15 +463,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSpaceToDepthNet(CNN2DFormat.NCHW, true))
.net2(getSpaceToDepthNet(CNN2DFormat.NCHW, false))
.net3(getSpaceToDepthNet(CNN2DFormat.NHWC, true))
.net4(getSpaceToDepthNet(CNN2DFormat.NHWC, false))
.net1(getSpaceToDepthNet(dataType,CNN2DFormat.NCHW, true))
.net2(getSpaceToDepthNet(dataType,CNN2DFormat.NCHW, false))
.net3(getSpaceToDepthNet(dataType,CNN2DFormat.NHWC, true))
.net4(getSpaceToDepthNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -465,7 +486,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testSpaceToBatchNet(){
@MethodSource("params")
@ParameterizedTest
public void testSpaceToBatchNet(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
Nd4j.getRandom().setSeed(12345);
@ -473,15 +496,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers" : "No helpers";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 16, 16);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 16, 16);
INDArray labels = TestUtils.randomOneHot(8, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getSpaceToBatchNet(CNN2DFormat.NCHW, true))
.net2(getSpaceToBatchNet(CNN2DFormat.NCHW, false))
.net3(getSpaceToBatchNet(CNN2DFormat.NHWC, true))
.net4(getSpaceToBatchNet(CNN2DFormat.NHWC, false))
.net1(getSpaceToBatchNet(dataType,CNN2DFormat.NCHW, true))
.net2(getSpaceToBatchNet(dataType,CNN2DFormat.NCHW, false))
.net3(getSpaceToBatchNet(dataType,CNN2DFormat.NHWC, true))
.net4(getSpaceToBatchNet(dataType,CNN2DFormat.NHWC, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -496,7 +519,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
@Test
public void testLocallyConnected() {
@MethodSource("params")
@ParameterizedTest
public void testLocallyConnected(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
@ -505,15 +530,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getLocallyConnectedNet(CNN2DFormat.NCHW, true, cm))
.net2(getLocallyConnectedNet(CNN2DFormat.NCHW, false, cm))
.net3(getLocallyConnectedNet(CNN2DFormat.NHWC, true, cm))
.net4(getLocallyConnectedNet(CNN2DFormat.NHWC, false, cm))
.net1(getLocallyConnectedNet(dataType,CNN2DFormat.NCHW, true, cm))
.net2(getLocallyConnectedNet(dataType,CNN2DFormat.NCHW, false, cm))
.net3(getLocallyConnectedNet(dataType,CNN2DFormat.NHWC, true, cm))
.net4(getLocallyConnectedNet(dataType,CNN2DFormat.NHWC, false, cm))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -530,7 +555,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
@Test
public void testGlobalPooling() {
@MethodSource("params")
@ParameterizedTest
public void testGlobalPooling(DataType dataType) {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (PoolingType pt : PoolingType.values()) {
@ -539,15 +566,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true))
.net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false))
.net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true))
.net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false))
.net1(getGlobalPoolingNet(dataType,CNN2DFormat.NCHW, pt, true))
.net2(getGlobalPoolingNet(dataType,CNN2DFormat.NCHW, pt, false))
.net3(getGlobalPoolingNet(dataType,CNN2DFormat.NHWC, pt, true))
.net4(getGlobalPoolingNet(dataType,CNN2DFormat.NHWC, pt, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
@ -562,9 +589,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new ConvolutionLayer.Builder()
return getNetWithLayer(dataType,new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
@ -573,7 +600,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new ConvolutionLayer.Builder()
return getNetWithLayer(dataType,new ConvolutionLayer.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
@ -583,16 +610,16 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getSubsampling2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getSubsampling2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new SubsamplingLayer.Builder()
return getNetWithLayer(dataType,new SubsamplingLayer.Builder()
.kernelSize(2, 2)
.stride(1, 1)
.dataFormat(format)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new SubsamplingLayer.Builder()
return getNetWithLayer(dataType,new SubsamplingLayer.Builder()
.kernelSize(2, 2)
.stride(1, 1)
.helperAllowFallback(false)
@ -600,9 +627,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getSeparableConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getSeparableConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new SeparableConvolution2D.Builder()
return getNetWithLayer(dataType,new SeparableConvolution2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
@ -611,7 +638,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new SeparableConvolution2D.Builder()
return getNetWithLayer(dataType,new SeparableConvolution2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
@ -621,9 +648,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getDepthwiseConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getDepthwiseConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
return getNetWithLayer(dataType,new DepthwiseConvolution2D.Builder()
.depthMultiplier(2)
.kernelSize(3, 3)
.stride(2, 2)
@ -633,7 +660,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
return getNetWithLayer(dataType,new DepthwiseConvolution2D.Builder()
.depthMultiplier(2)
.kernelSize(3, 3)
.stride(2, 2)
@ -644,59 +671,59 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getLrnLayer(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getLrnLayer(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new LocalResponseNormalization.Builder()
return getNetWithLayer(dataType,new LocalResponseNormalization.Builder()
.dataFormat(format)
.helperAllowFallback(false)
.build(), format, cm, null);
} else {
return getNetWithLayer(new LocalResponseNormalization.Builder()
return getNetWithLayer(dataType,new LocalResponseNormalization.Builder()
.helperAllowFallback(false)
.build(), format, cm, null);
}
}
private MultiLayerNetwork getZeroPaddingNet(CNN2DFormat format, boolean setOnLayerAlso) {
private MultiLayerNetwork getZeroPaddingNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2)
return getNetWithLayer(dataType,new ZeroPaddingLayer.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2).build(),
return getNetWithLayer(dataType,new ZeroPaddingLayer.Builder(2,2).build(),
format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getCropping2dNet(CNN2DFormat format, boolean setOnLayerAlso) {
private MultiLayerNetwork getCropping2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new Cropping2D.Builder(2,2)
return getNetWithLayer(dataType,new Cropping2D.Builder(2,2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new Cropping2D.Builder(2,2)
return getNetWithLayer(dataType,new Cropping2D.Builder(2,2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getUpsamplingNet(CNN2DFormat format, boolean setOnLayerAlso) {
private MultiLayerNetwork getUpsamplingNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new Upsampling2D.Builder(2)
return getNetWithLayer(dataType,new Upsampling2D.Builder(2)
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new Upsampling2D.Builder(2)
return getNetWithLayer(dataType,new Upsampling2D.Builder(2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getDeconv2DNet2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getDeconv2DNet2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
return getNetWithLayer(dataType,new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH)
.kernelSize(2,2)
.dataFormat(format)
.stride(2,2)
.build(), format, cm, null);
} else {
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
return getNetWithLayer(dataType,new Deconvolution2D.Builder().nOut(2)
.activation(Activation.TANH)
.kernelSize(2,2)
.dataFormat(format)
@ -705,50 +732,50 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getBatchNormNet(boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) {
private MultiLayerNetwork getBatchNormNet(DataType dataType,boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new BatchNormalization.Builder()
return getNetWithLayer(dataType,new BatchNormalization.Builder()
.useLogStd(logStdev)
.dataFormat(format)
.helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new BatchNormalization.Builder()
return getNetWithLayer(dataType,new BatchNormalization.Builder()
.useLogStd(logStdev)
.helperAllowFallback(false)
.nOut(3).build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) {
private MultiLayerNetwork getSpaceToDepthNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToDepthLayer.Builder()
return getNetWithLayer(dataType,new SpaceToDepthLayer.Builder()
.blocks(2)
.dataFormat(format)
.build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new SpaceToDepthLayer.Builder()
return getNetWithLayer(dataType,new SpaceToDepthLayer.Builder()
.blocks(2)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) {
private MultiLayerNetwork getSpaceToBatchNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToBatchLayer.Builder()
return getNetWithLayer(dataType,new SpaceToBatchLayer.Builder()
.blocks(2, 2)
.dataFormat(format)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
} else {
return getNetWithLayer(new SpaceToBatchLayer.Builder()
return getNetWithLayer(dataType,new SpaceToBatchLayer.Builder()
.blocks(2, 2)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
}
}
private MultiLayerNetwork getLocallyConnectedNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
private MultiLayerNetwork getLocallyConnectedNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new LocallyConnected2D.Builder()
return getNetWithLayer(dataType,new LocallyConnected2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
@ -756,7 +783,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.nOut(3)
.build(), format, cm, null);
} else {
return getNetWithLayer(new LocallyConnected2D.Builder()
return getNetWithLayer(dataType,new LocallyConnected2D.Builder()
.kernelSize(3, 3)
.stride(2, 2)
.activation(Activation.TANH)
@ -765,9 +792,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) {
private MultiLayerNetwork getNetWithLayer(DataType dataType,Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) {
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.dataType(this.dataType)
.dataType(dataType)
.seed(12345)
.convolutionMode(cm)
.list()
@ -794,13 +821,13 @@ public class ConvDataFormatTests extends BaseDL4JTest {
return net;
}
private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
private MultiLayerNetwork getGlobalPoolingNet(DataType dataType,CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
return getNetWithLayer(dataType,new GlobalPoolingLayer.Builder(pt)
.poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2})
.build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
return getNetWithLayer(dataType,new GlobalPoolingLayer.Builder(pt)
.build(), format, ConvolutionMode.Same, null);
}
}

BidirectionalTest.java

@ -45,8 +45,11 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -61,30 +64,29 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
@RunWith(Parameterized.class)
@DisplayName("Bidirectional Test")
class BidirectionalTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public BidirectionalTest(RNNFormat rnnDataFormat) {
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params() {
return RNNFormat.values();
public static Stream<Arguments> params() {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Test
@DisplayName("Compare Implementations")
void compareImplementations() {
@ParameterizedTest
@MethodSource("params")
void compareImplementations(RNNFormat rnnDataFormat) {
for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm);
// Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params
@ -147,9 +149,11 @@ class BidirectionalTest extends BaseDL4JTest {
}
}
@Test
@DisplayName("Compare Implementations Comp Graph")
void compareImplementationsCompGraph() {
@Test
@ParameterizedTest
@MethodSource("params")
void compareImplementationsCompGraph(RNNFormat rnnFormat) {
// for(WorkspaceMode wsm : WorkspaceMode.values()) {
for (WorkspaceMode wsm : new WorkspaceMode[] { WorkspaceMode.NONE, WorkspaceMode.ENABLED }) {
log.info("*** Starting workspace mode: " + wsm);
@ -187,8 +191,8 @@ class BidirectionalTest extends BaseDL4JTest {
Gradient g2 = net2.gradient();
assertEquals(g1.gradient(), g2.gradient());
// Ensure updates are equal:
ComputationGraphUpdater u1 = (ComputationGraphUpdater) net1.getUpdater();
ComputationGraphUpdater u2 = (ComputationGraphUpdater) net2.getUpdater();
ComputationGraphUpdater u1 = net1.getUpdater();
ComputationGraphUpdater u2 = net2.getUpdater();
assertEquals(u1.getUpdaterStateViewArray(), u2.getUpdaterStateViewArray());
u1.update(g1, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces());
u2.update(g2, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces());
@ -205,7 +209,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test
@DisplayName("Test Serialization")
void testSerialization() throws Exception {
@ParameterizedTest
@MethodSource("params")
void testSerialization(RNNFormat rnnDataFormat) throws Exception {
for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345);
@ -242,7 +248,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test
@DisplayName("Test Serialization Comp Graph")
void testSerializationCompGraph() throws Exception {
@ParameterizedTest
@MethodSource("params")
void testSerializationCompGraph(RNNFormat rnnDataFormat) throws Exception {
for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345);
@ -277,7 +285,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test
@DisplayName("Test Simple Bidirectional")
void testSimpleBidirectional() {
@ParameterizedTest
@MethodSource("params")
public void testSimpleBidirectional(RNNFormat rnnDataFormat) {
for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345);
@ -362,7 +372,9 @@ class BidirectionalTest extends BaseDL4JTest {
@Test
@DisplayName("Test Simple Bidirectional Comp Graph")
void testSimpleBidirectionalCompGraph() {
@ParameterizedTest
@MethodSource("params")
void testSimpleBidirectionalCompGraph(RNNFormat rnnDataFormat) {
for (WorkspaceMode wsm : WorkspaceMode.values()) {
log.info("*** Starting workspace mode: " + wsm);
Nd4j.getRandom().setSeed(12345);

GravesBidirectionalLSTMTest.java

@ -19,7 +19,6 @@
*/
package org.deeplearning4j.nn.layers.recurrent;
import junit.framework.TestCase;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@ -34,9 +33,12 @@ import org.deeplearning4j.nn.params.GravesBidirectionalLSTMParamInitializer;
import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -44,31 +46,29 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.common.primitives.Pair;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(Parameterized.class)
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Graves Bidirectional LSTM Test")
class GravesBidirectionalLSTMTest extends BaseDL4JTest {
private double score = 0.0;
private RNNFormat rnnDataFormat;
public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat) {
this.rnnDataFormat = rnnDataFormat;
public static Stream<Arguments> params(){
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Parameterized.Parameters
public static Object[] params() {
return RNNFormat.values();
}
@Test
@DisplayName("Test Bidirectional LSTM Graves Forward Basic")
void testBidirectionalLSTMGravesForwardBasic() {
@MethodSource("params")
@ParameterizedTest
void testBidirectionalLSTMGravesForwardBasic(RNNFormat rnnDataFormat) {
// Very basic test of forward prop. of LSTM layer with a time series.
// Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape.
int nIn = 13;
@ -110,19 +110,21 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test
@DisplayName("Test Bidirectional LSTM Graves Backward Basic")
void testBidirectionalLSTMGravesBackwardBasic() {
@MethodSource("params")
@ParameterizedTest
void testBidirectionalLSTMGravesBackwardBasic(RNNFormat rnnDataFormat) {
// Very basic test of backprop for mini-batch + time series
// Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape.
testGravesBackwardBasicHelper(13, 3, 17, 10, 7);
testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 7);
// Edge case: miniBatchSize = 1
testGravesBackwardBasicHelper(13, 3, 17, 1, 7);
testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 1, 7);
// Edge case: timeSeriesLength = 1
testGravesBackwardBasicHelper(13, 3, 17, 10, 1);
testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 1);
// Edge case: both miniBatchSize = 1 and timeSeriesLength = 1
testGravesBackwardBasicHelper(13, 3, 17, 1, 1);
testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 1, 1);
}
private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) {
private void testGravesBackwardBasicHelper(RNNFormat rnnDataFormat,int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) {
INDArray inputData = (rnnDataFormat == RNNFormat.NCW) ? Nd4j.ones(miniBatchSize, nIn, timeSeriesLength) : Nd4j.ones(miniBatchSize, timeSeriesLength, nIn);
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat).dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()).build();
long numParams = conf.getLayer().initializer().numParams(conf);
@ -204,7 +206,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test
@DisplayName("Test Get Set Parmas")
void testGetSetParmas() {
@MethodSource("params")
@ParameterizedTest
void testGetSetParmas(RNNFormat rnnDataFormat) {
final int nIn = 2;
final int layerSize = 3;
final int miniBatchSize = 2;
@ -224,7 +228,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test
@DisplayName("Test Simple Forwards And Backwards Activation")
void testSimpleForwardsAndBackwardsActivation() {
@MethodSource("params")
@ParameterizedTest
void testSimpleForwardsAndBackwardsActivation(RNNFormat rnnDataFormat) {
final int nIn = 2;
final int layerSize = 3;
final int miniBatchSize = 1;
@ -342,7 +348,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
@Test
@DisplayName("Test Gate Activation Fns Sanity Check")
void testGateActivationFnsSanityCheck() {
@MethodSource("params")
@ParameterizedTest
void testGateActivationFnsSanityCheck(RNNFormat rnnDataFormat) {
for (String gateAfn : new String[] { "sigmoid", "hardsigmoid" }) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat).activation(Activation.TANH).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);

MaskZeroLayerTest.java

@ -30,36 +30,35 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.Arrays;
import java.util.Collections;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(Parameterized.class)
@DisplayName("Mask Zero Layer Test")
class MaskZeroLayerTest extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public MaskZeroLayerTest(RNNFormat rnnDataFormat) {
this.rnnDataFormat = rnnDataFormat;
public static Stream<Arguments> params() {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Parameterized.Parameters
public static Object[] params() {
return RNNFormat.values();
}
@Test
@DisplayName("Activate")
void activate() {
@Test
@ParameterizedTest
@MethodSource("params")
void activate(RNNFormat rnnDataFormat) {
// GIVEN two examples where some of the timesteps are zero.
INDArray ex1 = Nd4j.create(new double[][] { new double[] { 0, 3, 5 }, new double[] { 0, 0, 2 } });
INDArray ex2 = Nd4j.create(new double[][] { new double[] { 0, 0, 2 }, new double[] { 0, 0, 2 } });
@ -95,9 +94,12 @@ class MaskZeroLayerTest extends BaseDL4JTest {
assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6);
}
@Test
@DisplayName("Test Serialization")
void testSerialization() {
@Test
@ParameterizedTest
@MethodSource("params")
void testSerialization(RNNFormat rnnDataFormat) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder().setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

RnnDataFormatTests.java

@ -40,8 +40,10 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -51,30 +53,31 @@ import org.nd4j.common.primitives.Pair;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
@AllArgsConstructor
public class RnnDataFormatTests extends BaseDL4JTest {
private boolean helpers;
private boolean lastTimeStep;
private boolean maskZeros;
@Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}")
public static List params(){
public static Stream<Arguments> params() {
List<Object[]> ret = new ArrayList<>();
for (boolean helpers: new boolean[]{true, false})
for (boolean lastTimeStep: new boolean[]{true, false})
for (boolean maskZero: new boolean[]{true, false})
ret.add(new Object[]{helpers, lastTimeStep, maskZero});
return ret;
return ret.stream().map(Arguments::of);
}
@Test
public void testSimpleRnn() {
@MethodSource("params")
@ParameterizedTest
public void testSimpleRnn(boolean helpers,
boolean lastTimeStep,
boolean maskZeros
) {
try {
Nd4j.getRandom().setSeed(12345);
@ -107,7 +110,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
}
@Test
public void testLSTM() {
@ParameterizedTest
@MethodSource("params")
public void testLSTM(boolean helpers,
boolean lastTimeStep,
boolean maskZeros) {
try {
Nd4j.getRandom().setSeed(12345);
@ -141,7 +148,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
@Test
public void testGraveLSTM() {
@MethodSource("params")
@ParameterizedTest
public void testGraveLSTM(boolean helpers,
boolean lastTimeStep,
boolean maskZeros) {
try {
Nd4j.getRandom().setSeed(12345);
@ -175,7 +186,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
@Test
public void testGraveBiLSTM() {
@MethodSource("params")
@ParameterizedTest
public void testGraveBiLSTM(boolean helpers,
boolean lastTimeStep,
boolean maskZeros) {
try {
Nd4j.getRandom().setSeed(12345);

TestLastTimeStepLayer.java

@ -34,14 +34,20 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.learning.config.AdaGrad;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.deeplearning4j.nn.api.OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
import static org.deeplearning4j.nn.weights.WeightInit.XAVIER_UNIFORM;
import static org.junit.jupiter.api.Assertions.*;
@ -50,20 +56,16 @@ import static org.nd4j.linalg.activations.Activation.TANH;
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;
@RunWith(Parameterized.class)
public class TestLastTimeStepLayer extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestLastTimeStepLayer(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters(name="{0}")
public static Object[] params(){
return RNNFormat.values();
public static Stream<Arguments> params(){
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
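A purely stylistic aside on the provider just above: when the parameters are simply an enum's values, Arrays.stream(...) is an equivalent and slightly more direct form than going through Arrays.asList(...).stream(). A sketch of the alternative, behaviourally identical:

public static Stream<Arguments> params() {
    return Arrays.stream(RNNFormat.values()).map(Arguments::of);
}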
@Test
public void testLastTimeStepVertex() {
@ParameterizedTest
@MethodSource("#params")
public void testLastTimeStepVertex(RNNFormat rnnDataFormat) {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
.addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder()
@ -126,7 +128,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
}
@Test
public void testMaskingAndAllMasked(){
@ParameterizedTest
@MethodSource("#params")
public void testMaskingAndAllMasked(RNNFormat rnnDataFormat) {
ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
.optimizationAlgo(STOCHASTIC_GRADIENT_DESCENT)
.weightInit(XAVIER_UNIFORM)

View File

@ -36,8 +36,11 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.enums.RnnDataFormat;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -49,25 +52,23 @@ import org.nd4j.common.primitives.Pair;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class TestRnnLayers extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestRnnLayers(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
public static Stream<Arguments> params(){
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Test
public void testTimeStepIs3Dimensional() {
@ParameterizedTest
@MethodSource("#params")
public void testTimeStepIs3Dimensional(RNNFormat rnnDataFormat) {
int nIn = 12;
int nOut = 3;
@ -117,7 +118,9 @@ public class TestRnnLayers extends BaseDL4JTest {
}
@Test
public void testDropoutRecurrentLayers(){
@ParameterizedTest
@MethodSource("#params")
public void testDropoutRecurrentLayers(RNNFormat rnnDataFormat){
Nd4j.getRandom().setSeed(12345);
String[] layerTypes = new String[]{"graves", "lstm", "simple"};
@ -215,9 +218,11 @@ public class TestRnnLayers extends BaseDL4JTest {
}
@Test
public void testMismatchedInputLabelLength(){
@ParameterizedTest
@MethodSource("#params")
public void testMismatchedInputLabelLength(RNNFormat rnnDataFormat){
for( int i=0; i<2; i++ ){
for( int i = 0; i < 2; i++) {
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()

View File

@ -29,8 +29,10 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -38,25 +40,25 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;
@RunWith(Parameterized.class)
public class TestSimpleRnn extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestSimpleRnn(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
public static Stream<Arguments> params() {
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Test
public void testSimpleRnn(){
@ParameterizedTest
@MethodSource("#params")
public void testSimpleRnn(RNNFormat rnnDataFormat) {
Nd4j.getRandom().setSeed(12345);
int m = 3;
@ -125,7 +127,9 @@ public class TestSimpleRnn extends BaseDL4JTest {
}
@Test
public void testBiasInit(){
@ParameterizedTest
@MethodSource("#params")
public void testBiasInit(RNNFormat rnnDataFormat) {
Nd4j.getRandom().setSeed(12345);
int nIn = 5;
int layerSize = 6;

View File

@ -37,8 +37,10 @@ import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -47,22 +49,22 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class TestTimeDistributed extends BaseDL4JTest {
private RNNFormat rnnDataFormat;
public TestTimeDistributed(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
}
@Parameterized.Parameters
public static Object[] params(){
return RNNFormat.values();
public static Stream<Arguments> params(){
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
}
@Test
public void testTimeDistributed(){
@ParameterizedTest
@MethodSource("#params")
public void testTimeDistributed(RNNFormat rnnDataFormat){
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
@ -133,10 +135,12 @@ public class TestTimeDistributed extends BaseDL4JTest {
@Test
public void testTimeDistributedDense(){
@MethodSource("#params")
@ParameterizedTest
public void testTimeDistributedDense(RNNFormat rnnDataFormat){
for( int rnnType=0; rnnType<3; rnnType++ ) {
for( int ffType=0; ffType<3; ffType++ ) {
for( int rnnType = 0; rnnType < 3; rnnType++ ) {
for( int ffType = 0; ffType < 3; ffType++ ) {
Layer l0, l2;
switch (rnnType) {

View File

@ -39,8 +39,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -145,92 +145,6 @@
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-plugin</artifactId>
<version>1.4.30-M1</version>
<configuration>
<args>
<arg>-Xjsr305=strict</arg>
</args>
<compilerPlugins>
<plugin>spring</plugin>
<plugin>jpa</plugin>
</compilerPlugins>
</configuration>
<dependencies>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-allopen</artifactId>
<version>${kotlin.version}</version>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-noarg</artifactId>
<version>${kotlin.version}</version>
</dependency>
</dependencies>
<executions>
<execution>
<id>compile</id>
<goals> <goal>compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/main/stubs</sourceDir>
<sourceDir>${project.basedir}/src/main/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/main/java</sourceDir>
<sourceDir>${project.basedir}/src/main/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
<execution>
<id>test-compile</id>
<goals> <goal>test-compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/test/stubs</sourceDir>
<sourceDir>${project.basedir}/src/test/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/test/java</sourceDir>
<sourceDir>${project.basedir}/src/test/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
</executions>
</plugin>
<!-- https://kotlinlang.org/docs/reference/using-maven.html -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.5.1</version>
<executions>
<!-- Replacing default-compile as it is treated specially by maven -->
<execution>
<id>default-compile</id>
<phase>none</phase>
</execution>
<!-- Replacing default-testCompile as it is treated specially by maven -->
<execution>
<id>default-testCompile</id>
<phase>none</phase>
</execution>
<execution>
<id>java-compile</id>
<phase>compile</phase>
<goals> <goal>compile</goal> </goals>
</execution>
<execution>
<id>java-test-compile</id>
<phase>test-compile</phase>
<goals> <goal>testCompile</goal> </goals>
</execution>
</executions>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
</plugins>
</build>
@ -244,7 +158,10 @@
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-stdlib-jdk8</artifactId>
@ -261,11 +178,14 @@
<groupId>org.nd4j</groupId>
<artifactId>samediff-import-tensorflow</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>samediff-import-onnx</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>

View File

@ -22,11 +22,6 @@ package org.nd4j;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.imports.tfgraphs.TFGraphTestAllLibnd4j;
import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
import org.nd4j.imports.tfgraphs.TFGraphTestList;
import org.nd4j.imports.tfgraphs.TFGraphTestZooModels;
import org.nd4j.imports.listeners.ImportModelDebugger;
import java.util.*;
@Slf4j
@ -36,11 +31,6 @@ public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>(Arrays.asList(
TFGraphTestAllSameDiff.class,
TFGraphTestAllLibnd4j.class,
TFGraphTestList.class,
TFGraphTestZooModels.class,
ImportModelDebugger.class //Run manually only, otherwise ignored
));
}

View File

@ -20,19 +20,16 @@
package org.nd4j;
import org.bytedeco.javacpp.Loader;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.nd4j.autodiff.opvalidation.*;
import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
//import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.function.Function;
import static org.junit.Assume.assumeFalse;
@ -49,7 +46,7 @@ import static org.junit.Assume.assumeFalse;
TransformOpValidation.class,
//TF import tests
TFGraphTestAllSameDiff.class
//TFGraphTestAllSameDiff.class
//TFGraphTestAllLibnd4j.class
})
//IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test"

View File

@ -27,10 +27,12 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.ImportClassMapping;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.NoOp;
import org.nd4j.linalg.api.ops.compat.CompatSparseToDense;
@ -122,13 +124,11 @@ import java.util.regex.Matcher;
import java.util.regex.Pattern;
@Disabled("No longer relevant after model import rewrite.")
public class TestOpMapping extends BaseNd4jTest {
public class TestOpMapping extends BaseNd4jTestWithBackends {
Set<Class<? extends DifferentialFunction>> subTypes;
public TestOpMapping(Nd4jBackend b){
super(b);
public TestOpMapping() {
Reflections reflections = new Reflections("org.nd4j");
subTypes = reflections.getSubTypesOf(DifferentialFunction.class);
}
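The backend-parameterized ND4J tests below reference their provider by fully qualified name, which is how JUnit 5 locates a factory method declared in another class. A hypothetical sketch of the shape such an external provider takes — the real signature of org.nd4j.linalg.BaseNd4jTest#configs is assumed here, not copied from the repository:

// Assumed to live in org.nd4j.linalg.BaseNd4jTest: a public static factory shared by many test classes
public static Stream<Arguments> configs() {
    // one Arguments per available backend; the real implementation may discover backends dynamically
    return Stream.of(Arguments.of(Nd4j.getBackend()));
}

// In a test class: fully qualified reference, with the backend injected as a method parameter
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
void runsOncePerBackend(Nd4jBackend backend) {
    // test body uses the injected backend
}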
@ -146,6 +146,8 @@ public class TestOpMapping extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOpMappingCoverage() throws Exception {
Map<String, DifferentialFunction> opNameMapping = ImportClassMapping.getOpNameMapping();
Map<String, DifferentialFunction> tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions();
@ -196,7 +198,9 @@ public class TestOpMapping extends BaseNd4jTest {
}
@Test
public void testOpsInNamespace() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOpsInNamespace(Nd4jBackend backend) throws Exception {
//Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't
// want to add to a namespace for some reason)
//Note that we ignore "*Bp", "*Gradient", "*Derivative" etc ops
@ -354,8 +358,11 @@ public class TestOpMapping extends BaseNd4jTest {
s.add(Assign.class);
}
@Test @Disabled
public void generateOpClassList() throws Exception{
@Test
@Disabled
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void generateOpClassList(Nd4jBackend backend) throws Exception{
Reflections reflections = new Reflections("org.nd4j");
Set<Class<? extends DifferentialFunction>> subTypes = reflections.getSubTypesOf(DifferentialFunction.class);
@ -366,12 +373,7 @@ public class TestOpMapping extends BaseNd4jTest {
l.add(c);
}
Collections.sort(l, new Comparator<Class<?>>() {
@Override
public int compare(Class<?> o1, Class<?> o2) {
return o1.getName().compareTo(o2.getName());
}
});
Collections.sort(l, Comparator.comparing(Class::getName));
for(Class<?> c : l){
System.out.println(c.getName() + ".class,");

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SDVariable;
@ -31,7 +33,7 @@ import org.nd4j.autodiff.samediff.internal.FrameIter;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.memory.NoOpMemoryMgr;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -46,19 +48,17 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
public class TestSessions extends BaseNd4jTest {
public TestSessions(Nd4jBackend b){
super(b);
}
public class TestSessions extends BaseNd4jTestWithBackends {
@Override
public char ordering(){
public char ordering() {
return 'c';
}
@Test
public void testInferenceSessionBasic(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInferenceSessionBasic(Nd4jBackend backend) {
//So far: trivial test to check execution order
SameDiff sd = SameDiff.create();
@ -90,7 +90,9 @@ public class TestSessions extends BaseNd4jTest {
@Test
public void testInferenceSessionBasic2(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInferenceSessionBasic2(Nd4jBackend backend) {
//So far: trivial test to check execution order
SameDiff sd = SameDiff.create();
@ -126,7 +128,9 @@ public class TestSessions extends BaseNd4jTest {
}
@Test
public void testMergeSimple(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeSimple(Nd4jBackend backend) {
//This isn't really a sensible graph, as merge op behaviour is undefined when multiple inputs are available...
SameDiff sd = SameDiff.create();
@ -162,7 +166,9 @@ public class TestSessions extends BaseNd4jTest {
@Test
public void testSwitchSimple(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSwitchSimple(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 3,3);

View File

@ -21,10 +21,12 @@
package org.nd4j.autodiff.internal;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.internal.DependencyList;
import org.nd4j.autodiff.samediff.internal.DependencyTracker;
import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -35,19 +37,18 @@ import java.util.Collections;
import static junit.framework.TestCase.assertNotNull;
import static org.junit.jupiter.api.Assertions.*;
public class TestDependencyTracker extends BaseNd4jTest {
public class TestDependencyTracker extends BaseNd4jTestWithBackends {
public TestDependencyTracker(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
return 'c';
}
@Test
public void testSimple(){
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>();
@ -93,8 +94,10 @@ public class TestDependencyTracker extends BaseNd4jTest {
assertTrue(dt.isEmpty());
}
@Test
public void testSatisfiedBeforeAdd(){
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSatisfiedBeforeAdd(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>();
//Check different order of adding dependencies: i.e., mark X as satisfied, then add x -> y dependency
@ -132,8 +135,10 @@ public class TestDependencyTracker extends BaseNd4jTest {
assertFalse(dt.hasNewAllSatisfied());
}
@Test
public void testMarkUnsatisfied(){
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMarkUnsatisfied(Nd4jBackend backend){
DependencyTracker<String,String> dt = new DependencyTracker<>();
dt.addDependency("y", "x");
@ -164,7 +169,9 @@ public class TestDependencyTracker extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIdentityDependencyTracker(){
IdentityDependencyTracker<INDArray, String> dt = new IdentityDependencyTracker<>();
assertTrue(dt.isEmpty());

View File

@ -21,6 +21,8 @@
package org.nd4j.autodiff.opvalidation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.GradCheckUtil;
@ -38,12 +40,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class ActivationGradChecks extends BaseOpValidation {
public ActivationGradChecks(Nd4jBackend backend) {
super(backend);
}
@Test
public void testActivationGradientCheck1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testActivationGradientCheck1(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4));
@ -61,7 +62,9 @@ public class ActivationGradChecks extends BaseOpValidation {
}
@Test
public void testActivationGradientCheck2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testActivationGradientCheck2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.DOUBLE, 3, 4);

View File

@ -21,18 +21,14 @@
package org.nd4j.autodiff.opvalidation;
import org.junit.jupiter.api.BeforeEach;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
public abstract class BaseOpValidation extends BaseNd4jTest {
public abstract class BaseOpValidation extends BaseNd4jTestWithBackends {
private DataType initialType;
private DataType initialType = Nd4j.dataType();
public BaseOpValidation(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {

View File

@ -27,6 +27,8 @@ import java.util.List;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.OpValidation;
@ -65,9 +67,6 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class LayerOpValidation extends BaseOpValidation {
public LayerOpValidation(Nd4jBackend backend) {
super(backend);
}
@Override
public long getTimeoutMilliseconds() {
@ -75,7 +74,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testXwPlusB() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testXwPlusB(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sameDiff = SameDiff.create();
@ -109,7 +110,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testReluLayer() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReluLayer(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sameDiff = SameDiff.create();
@ -137,7 +140,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testBiasAdd() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAdd(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sameDiff = SameDiff.create();
@ -161,7 +166,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testConv2d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv2d(Nd4jBackend backend) {
//avg pool, batch norm, conv2d, max pool 2d, pooling2d, upsampling
//Tested elsewhere: deconv2d, depthwise2d, LRN, sconv2d
@ -301,7 +308,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testLrn2d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLrn2d(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}};
@ -342,7 +351,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testIm2Col() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIm2Col(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873
Nd4j.getRandom().setSeed(12345);
@ -381,7 +392,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testOutputShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOutputShape(Nd4jBackend backend) {
long[] inSize = {1, 8, 8, 3};
SameDiff sd = SameDiff.create();
@ -431,7 +444,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testAvgPool() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAvgPool(Nd4jBackend backend) {
long[] inSize = {1, 8, 8, 3}; //NHWC
Pooling2DConfig conf = Pooling2DConfig.builder()
@ -474,7 +489,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testConv3d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv3d(Nd4jBackend backend) {
//Pooling3d, Conv3D, batch norm
Nd4j.getRandom().setSeed(12345);
@ -576,7 +593,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testDepthWiseConv2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepthWiseConv2dBasic(Nd4jBackend backend) {
int nIn = 3;
int depthWise = 4;
int kH = 2;
@ -615,7 +634,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testSeparableConv2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSeparableConv2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int nIn = 2;
int nOut = 3;
@ -671,7 +692,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testDeconv2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeconv2dBasic(Nd4jBackend backend) {
int nIn = 2;
int nOut = 3;
int kH = 2;
@ -715,7 +738,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testConv2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv2dBasic(Nd4jBackend backend) {
int nIn = 3;
int nOut = 4;
int kH = 2;
@ -756,7 +781,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testMaxPoolingArgMax() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxPoolingArgMax(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int nIn = 3;
int kH = 2;
@ -785,7 +812,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testMaxPooling2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxPooling2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int nIn = 3;
int kH = 2;
@ -843,7 +872,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testAvgPooling2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAvgPooling2dBasic(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int nIn = 3;
int kH = 2;
@ -892,7 +923,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testAvgPooling3dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAvgPooling3dBasic(Nd4jBackend backend) {
int nIn = 3;
int kH = 2;
int kW = 2;
@ -929,7 +962,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testMaxPooling3dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxPooling3dBasic(Nd4jBackend backend) {
int nIn = 3;
int kH = 2;
int kW = 2;
@ -967,7 +1002,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testConv1dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1dBasic(Nd4jBackend backend) {
int nIn = 3;
int nOut = 4;
int k = 2;
@ -1002,7 +1039,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testConv1dCausal() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1dCausal(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int nIn = 3;
int nOut = 4;
@ -1051,7 +1090,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testConv1dForward() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1dForward(Nd4jBackend backend) {
int nIn = 2;
int nOut = 1;
int kernel = 3;
@ -1094,7 +1135,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testConv3dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv3dBasic(Nd4jBackend backend) {
int nIn = 3;
int nOut = 4;
int kH = 2;
@ -1140,7 +1183,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testDeConv3dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeConv3dBasic(Nd4jBackend backend) {
int nIn = 4;
int nOut = 3;
int kH = 2;
@ -1185,7 +1230,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testLayerNorm() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNorm(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1210,7 +1257,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testLayerNorm4d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNorm4d(Nd4jBackend backend) {
int mb = 3;
int ch = 4;
for (boolean nchw : new boolean[]{true, false}) {
@ -1242,7 +1291,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testLayerNormOP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormOP(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1258,7 +1309,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testLayerNormNoBias() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormNoBias(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1281,7 +1334,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testLayerNormOPNoBias() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormOPNoBias(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
final INDArray standardized = random.ulike();
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
@ -1296,7 +1351,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testLayerNormNoDeviation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormNoDeviation(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
for (int i = 0; i < 4; i++) {
random.putScalar(1, i, 7);
@ -1326,36 +1383,36 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test()
public void exceptionThrown_WhenConv1DConfigInvalid() {
assertThrows(IllegalArgumentException.class,() -> {
int nIn = 3;
int nOut = 4;
int k = 2;
int mb = 3;
int img = 28;
public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
int nIn = 3;
int nOut = 4;
int k = 2;
int mb = 3;
int img = 28;
SameDiff sd = SameDiff.create();
INDArray wArr = Nd4j.create(k, nIn, nOut);
INDArray inArr = Nd4j.create(mb, nIn, img);
SameDiff sd = SameDiff.create();
INDArray wArr = Nd4j.create(k, nIn, nOut);
INDArray inArr = Nd4j.create(mb, nIn, img);
SDVariable in = sd.var("in", inArr);
SDVariable w = sd.var("W", wArr);
SDVariable in = sd.var("in", inArr);
SDVariable w = sd.var("W", wArr);
SDVariable[] vars = new SDVariable[]{in, w};
SDVariable[] vars = new SDVariable[]{in, w};
Conv1DConfig conv1DConfig = Conv1DConfig.builder()
.k(k).p(-1).s(0)
.paddingMode(PaddingMode.VALID)
.build();
Conv1DConfig conv1DConfig = Conv1DConfig.builder()
.k(k).p(-1).s(0)
.paddingMode(PaddingMode.VALID)
.build();
SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
});
});
}
@Test()
public void exceptionThrown_WhenConv2DConfigInvalid() {
public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345);
@ -1378,40 +1435,42 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test()
public void exceptionThrown_WhenConf3DInvalid() {
assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345);
public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345);
//NCDHW format
int[] inSizeNCDHW = {2, 3, 4, 5, 5};
//NCDHW format
int[] inSizeNCDHW = {2, 3, 4, 5, 5};
List<String> failed = new ArrayList<>();
List<String> failed = new ArrayList<>();
for (boolean ncdhw : new boolean[]{true, false}) {
int nIn = inSizeNCDHW[1];
int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW));
for (boolean ncdhw : new boolean[]{true, false}) {
int nIn = inSizeNCDHW[1];
int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW));
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", shape);
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", shape);
SDVariable out;
String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape);
SDVariable out;
String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape);
SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC]
SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10));
out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder()
.dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC)
.isSameMode(true)
.kH(2).kW(2).kD(2)
.sD(1).sH(1).sW(-1).dW(-1)
.build());
}
});
SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC]
SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10));
out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder()
.dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC)
.isSameMode(true)
.kH(2).kW(2).kD(2)
.sD(1).sH(1).sW(-1).dW(-1)
.build());
}
});
}
@Test
public void testLayerNormMixedOrders() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLayerNormMixedOrders(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');
INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f');
@ -1458,7 +1517,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void testBiasAdd_nchw_nhwc() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAdd_nchw_nhwc(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
for (boolean nchw : new boolean[]{true, false}) {
@ -1489,6 +1550,8 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepthwiseConv2D(){
int bS = 10;
@ -1527,7 +1590,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void LSTMLayerTestCase1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void LSTMLayerTestCase1(Nd4jBackend backend) {
int bS = 5;
int nIn = 3;
@ -1602,7 +1667,9 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void LSTMLayerTestCase2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void LSTMLayerTestCase2(Nd4jBackend backend) {
int bS = 5;
int nIn = 3;
int numUnits = 7;
@ -1660,7 +1727,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void LSTMLayerTestCase3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void LSTMLayerTestCase3(Nd4jBackend backend) {
int bS = 5;
int nIn = 3;
int numUnits = 7;
@ -1721,7 +1790,9 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test
public void GRUTestCase() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void GRUTestCase(Nd4jBackend backend) {
int bS = 5;
int nIn = 4;
int nOut = 6;

View File

@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
@ -43,9 +45,7 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class LossOpValidation extends BaseOpValidation {
public LossOpValidation(Nd4jBackend backend) {
super(backend);
}
@Override
public long getTimeoutMilliseconds() {
@ -56,7 +56,9 @@ public class LossOpValidation extends BaseOpValidation {
public static final Set<String> NO_BP_YET = new HashSet<>();
@Test
public void testLoss2d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoss2d(Nd4jBackend backend) {
final List<String> oneDimensionalOutputFns = Arrays.asList("cosine", "mpwse", "softmaxxent", "softmaxxent_smooth", "mpwse", "sparsesoftmax");
Nd4j.getRandom().setSeed(12345);
@ -69,7 +71,7 @@ public class LossOpValidation extends BaseOpValidation {
"absdiff", "cosine", "hinge", "huber", "log", "mse",
"sigmoidxent", "sigmoidxent_smooth", "softmaxxent", "softmaxxent_smooth", "mpwse",
"sparsesoftmax"
}) {
}) {
for(String weights : new String[]{"none", "scalar", "perExample", "perOutput"}) {
@ -368,6 +370,8 @@ public class LossOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCosineDistance(){
INDArray arr = Nd4j.create(new double[][]{{-0.3, -0.2, -0.1}, {0, 0.1, 0.2}});
INDArray label = Nd4j.create(new double[][]{{1.0, 2.0, 3.0}, {-1.0, 2.0, 1.0}});
@ -386,6 +390,8 @@ public class LossOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testL2Loss(){
for( int rank=0; rank<=3; rank++ ){
@ -428,7 +434,9 @@ public class LossOpValidation extends BaseOpValidation {
}
@Test
public void testNonZeroResult() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNonZeroResult(Nd4jBackend backend) {
INDArray predictions = Nd4j.rand(DataType.DOUBLE, 10, 5);
INDArray w = Nd4j.scalar(1.0);
INDArray label = Nd4j.rand(DataType.DOUBLE, 10, 5);
@ -486,6 +494,8 @@ public class LossOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void TestStdLossMixedDataType(){
// Default Data Type in this test suite is Double.
// This test used to throw an exception complaining that we have mixed data types.

View File

@ -23,6 +23,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -78,13 +80,12 @@ import static org.junit.Assume.assumeNotNull;
@Slf4j
public class MiscOpValidation extends BaseOpValidation {
public MiscOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test
public void testGradientAutoBroadcast1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGradientAutoBroadcast1(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
@ -171,7 +172,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testGradientAutoBroadcast2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGradientAutoBroadcast2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
@ -260,7 +263,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testGradientAutoBroadcast3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGradientAutoBroadcast3(Nd4jBackend backend) {
//These tests: output size > input sizes
Nd4j.getRandom().setSeed(12345);
@ -368,7 +373,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testScatterOpGradients() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScatterOpGradients(Nd4jBackend backend) {
List<String> failed = new ArrayList<>();
for (int i = 0; i < 7; i++) {
@ -470,6 +477,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScatterUpdate(){
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3);
INDArray updates = Nd4j.create(new float[][]{
@ -491,7 +500,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testGatherGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
@ -542,6 +553,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrace(){
//TODO need to work out how to handle shape_op for scalars...
//OpValidationSuite.ignoreFailing();
@ -567,7 +580,9 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
public void testTensorGradTensorMmul() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTensorGradTensorMmul(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing();
Nd4j.getRandom().setSeed(12345);
@ -589,7 +604,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testMulGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMulGradient(Nd4jBackend backend) {
INDArray arr1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray arr2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
@ -654,22 +671,21 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
public void testMmulGradientManual() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulGradientManual(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
Map<String, INDArray> inputs = new HashMap<>();
inputs.put("x", sumInput);
inputs.put("y", sumInput.dup());
sameDiff.defineFunction("mmulGradient", new SameDiffFunctionDefinition() {
@Override
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
SDVariable input = sameDiff.var("x", inputs.get("x"));
SDVariable input2 = sameDiff.var("y", inputs.get("y"));
SDVariable exp = sameDiff.mmul(input, input2);
SDVariable sum = sameDiff.sum(exp, Integer.MAX_VALUE);
return new SDVariable[]{sum};
}
sameDiff.defineFunction("mmulGradient", (sameDiff1, inputs1, variableInputs) -> {
SDVariable input = sameDiff1.var("x", inputs1.get("x"));
SDVariable input2 = sameDiff1.var("y", inputs1.get("y"));
SDVariable exp = sameDiff1.mmul(input, input2);
SDVariable sum = sameDiff1.sum(exp, Integer.MAX_VALUE);
return new SDVariable[]{sum};
}, inputs);
@ -698,6 +714,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulGradients(){
int[] aShape = new int[]{2,3};
int[] bShape = new int[]{3,4};
@ -749,7 +767,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testBatchMmulBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBatchMmulBasic(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6873
int M = 5;
int N = 3;
@ -774,7 +794,9 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
public void testMmulWithTranspose() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulWithTranspose(Nd4jBackend backend) {
//Here: [x,3]^T * [x,4] = [3,4]
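(As a concrete instance of the shape rule in that comment: with x = 5, a [5,3] input transposed to [3,5] and multiplied by a [5,4] input yields a [3,4] result.)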
@ -811,6 +833,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulOutputSizeCalculation(){
//[3,2] x [2,4] with result transpose: output shape [4,3]
INDArray a = Nd4j.create(3,2);
@ -820,7 +844,7 @@ public class MiscOpValidation extends BaseOpValidation {
.transposeA(false)
.transposeB(false)
.transposeResult(true)
.build());
.build());
val outShapes = Nd4j.getExecutioner().calculateOutputShape(m);
assertArrayEquals(new long[]{4,3}, outShapes.get(0).getShape());
@ -843,6 +867,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFillOp(){
INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT);
@ -857,6 +883,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm(){
//Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
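A worked sketch of the rule spelled out in that comment, assuming a clip value of 1.0 applied along dimension 1 (i.e. per row):

INDArray arr = Nd4j.rand(DataType.DOUBLE, 3, 5);
double clip = 1.0;
for (int r = 0; r < arr.rows(); r++) {
    INDArray row = arr.getRow(r);                 // a view, so the scaling below writes through to arr
    double n2 = row.norm2Number().doubleValue();
    if (n2 > clip) {
        row.muli(clip / n2);                      // rescaled so the row's norm2 becomes exactly clip
    }                                             // rows with norm2 <= clip are left unmodified
}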
@ -889,6 +917,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm2(){
//Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -932,6 +962,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm1(){
//Expected: if array.norm2(1) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -972,6 +1004,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByNorm0(){
//Expected: if array.norm2(0) is less than 1.0, not modified
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
@ -1001,6 +1035,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCumSum(){
List<String> failing = new ArrayList<>();
@ -1066,6 +1102,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCumProd(){
List<String> failing = new ArrayList<>();
@ -1134,6 +1172,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot1(){
List<String> failed = new ArrayList<>();
@ -1164,6 +1204,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHotOp(){
//https://www.tensorflow.org/api_docs/python/tf/one_hot
//https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp
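For reference, under the TensorFlow semantics those links describe, indices {0, 2, -1, 1} one-hot encoded at depth 3 produce the rows {1,0,0}, {0,0,1}, {0,0,0}, {0,1,0}: an out-of-range (negative) index yields a row of off-values.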
@ -1178,7 +1220,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testOneHot2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot2(Nd4jBackend backend) {
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
@ -1198,7 +1242,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testOneHot4() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot4(Nd4jBackend backend) {
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
@ -1218,7 +1264,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testOneHot3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneHot3(Nd4jBackend backend) {
//https://github.com/deeplearning4j/deeplearning4j/issues/6872
//https://www.tensorflow.org/api_docs/python/tf/one_hot
@ -1253,6 +1301,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLinspace(){
SameDiff sd = SameDiff.create();
SDVariable out = sd.linspace("linspace", DataType.DOUBLE, 1,10,10);
@ -1266,6 +1316,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLinspace2(){
OpValidationSuite.ignoreFailing(); //TODO 2019/01/18
SameDiff sd = SameDiff.create();
@ -1280,7 +1332,9 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
public void testShapeFn() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShapeFn(Nd4jBackend backend) {
INDArray in = Nd4j.create(new long[]{1, 2});
@ -1294,7 +1348,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testShapeFn2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShapeFn2(Nd4jBackend backend) {
INDArray i = Nd4j.create(1,3);
@ -1307,6 +1363,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeRank1(){
SameDiff sd = SameDiff.create();
SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5));
@ -1325,7 +1383,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testDiagPart() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagPart(Nd4jBackend backend) {
INDArray i = Nd4j.create(5,5);
SameDiff sd = SameDiff.create();
@ -1337,7 +1397,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testDiagShapeFn() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagShapeFn(Nd4jBackend backend) {
INDArray i = Nd4j.create(5,5);
CustomOp op = new DiagPart(i, null);
@ -1350,6 +1412,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZerosOnesLike(){
Nd4j.getRandom().setSeed(12345);
@ -1392,6 +1456,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZerosLikeOp(){
INDArray arr = Nd4j.scalar(DataType.DOUBLE, 1.0);
@ -1407,6 +1473,8 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConfusionMatrix(){
DataType dt = DataType.DOUBLE;
@ -1443,6 +1511,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIsNonDecreasingIsStrictlyIncr(){
List<long[]> shapes = Arrays.asList(null, new long[]{12}, new long[]{1,12}, new long[]{3,4}, new long[]{2,2,3});
@ -1506,6 +1576,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExtractImagePatches(){
/*
tf.reset_default_graph()
@ -1553,6 +1625,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentProdBpSimple(){
INDArray segmentIdxs = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT);
@ -1573,6 +1647,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulRank4() throws Exception {
Nd4j.getRandom().setSeed(12345);
@ -1608,6 +1684,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulRank4_simple(){
INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
@ -1634,6 +1712,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNthElementRank1(){
INDArray in = Nd4j.createFromArray(new double[]{0,1,2,3,4,5,6,7,8,9});
INDArray n = Nd4j.scalar(0);
@ -1656,6 +1736,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTensorMmulShape(){
INDArray a = Nd4j.create(new double[]{2}).reshape(1);
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
@ -1674,6 +1756,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTensorMmulShape2(){
INDArray a = Nd4j.create(new double[]{2}).reshape(1);
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
@ -1682,6 +1766,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStopGradient(){
SameDiff sd = SameDiff.create();
@ -1701,6 +1787,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckNumerics(){
OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/7927
@ -1744,7 +1832,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testCheckNumerics2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckNumerics2(Nd4jBackend backend) {
INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4);
INDArray msg = Nd4j.scalar("My error message!");
@ -1757,6 +1847,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testHistogramFixedWidth(){
//Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf]
INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9);
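With those bin edges, the seven inputs above should land as counts [3, 1, 2, 0, 1]: three values below 0.2, one in [0.2, 0.4), two in [0.4, 0.6), none in [0.6, 0.8), and one at or above 0.8.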
@ -1775,6 +1867,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicPartition(){
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
@ -1793,6 +1887,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testListDiff(){
INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
INDArray y = Nd4j.createFromArray(3, 1);
@ -1812,7 +1908,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testDivideNoNan() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDivideNoNan(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff()
SameDiff sameDiff = SameDiff.create();
@ -1836,7 +1934,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testDigamma() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDigamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -1851,7 +1951,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testFlatten() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFlatten(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1873,7 +1975,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testFusedBatchNorm() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFusedBatchNorm(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create();
@ -1918,7 +2022,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testIgamma() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIgamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -1934,7 +2040,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testIgammaC() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIgammaC(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -1951,7 +2059,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testLgamma() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLgamma(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1976,7 +2086,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testLu() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLu(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -2007,7 +2119,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testMatrixBandPart() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixBandPart(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create();
@ -2037,7 +2151,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testPolygamma() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPolygamma(Nd4jBackend backend) {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -2053,7 +2169,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testTriangularSolve() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTriangularSolve(Nd4jBackend backend) {
INDArray a = Nd4j.createFromArray(new float[]{
3.f, 0.f, 0.f, 0.f,
@ -2077,7 +2195,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testBiasAdd() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAdd(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -2106,7 +2226,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testBiasAddGrad() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBiasAddGrad(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -2126,7 +2248,9 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
public void testRoll() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRoll(Nd4jBackend backend) {
INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}).
@ -2146,6 +2270,8 @@ public class MiscOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSeqMask(){
INDArray arr = Nd4j.createFromArray(1,2,3);
INDArray maxLen = Nd4j.scalar(4);
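The pattern applied throughout these files: under JUnit 4 each op-validation class was run with @RunWith(Parameterized.class) and received the Nd4jBackend once per instance through its constructor; after this migration each test method carries @ParameterizedTest with a fully qualified @MethodSource and receives the backend per invocation. A minimal sketch of the migrated shape (class and method names here are illustrative, not taken from this commit):

    import org.junit.jupiter.params.ParameterizedTest;
    import org.junit.jupiter.params.provider.MethodSource;
    import org.nd4j.linalg.factory.Nd4jBackend;

    public class SomeOpValidation extends BaseOpValidation {

        // "fully.qualified.Class#method" points @MethodSource at a static factory in another class
        @ParameterizedTest
        @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
        public void testSomeOp(Nd4jBackend backend) {
            // the injected backend replaces the field formerly set by the JUnit 4 constructor
        }
    }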


@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -51,12 +53,11 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class RandomOpValidation extends BaseOpValidation {
public RandomOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test
public void testRandomOpsSDVarShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomOpsSDVarShape(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
@ -157,7 +158,9 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
public void testRandomOpsLongShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomOpsLongShape(Nd4jBackend backend) {
List<String> failed = new ArrayList<>();
for (long[] shape : Arrays.asList(new long[]{1000}, new long[]{100, 10}, new long[]{40, 5, 5})) {
@ -283,6 +286,8 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomBinomial(){
INDArray z = Nd4j.create(new long[]{10});
@ -293,7 +298,9 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
public void testUniformRankSimple() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUniformRankSimple(Nd4jBackend backend) {
INDArray arr = Nd4j.createFromArray(new double[]{100.0});
// OpTestCase tc = new OpTestCase(DynamicCustomOp.builder("randomuniform")
@ -325,7 +332,9 @@ public class RandomOpValidation extends BaseOpValidation {
@Test
public void testRandomExponential() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomExponential(Nd4jBackend backend) {
long length = 1_000_000;
INDArray shape = Nd4j.createFromArray(new double[]{length});
INDArray out = Nd4j.createUninitialized(new long[]{length});
@ -347,6 +356,8 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRange(){
//Technically deterministic, not random...
@ -380,6 +391,8 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllEmptyReduce(){
INDArray x = Nd4j.createFromArray(true, true, true);
All all = new All(x);
@ -389,6 +402,8 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUniformDtype(){
Nd4j.getRandom().setSeed(12345);
for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
@ -417,6 +432,8 @@ public class RandomOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomExponential2(){
Nd4j.getRandom().setSeed(12345);
DynamicCustomOp op = DynamicCustomOp.builder("random_exponential")
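Every @MethodSource above points at org.nd4j.linalg.BaseNd4jTest#configs, a provider that is not part of this diff. Assuming it follows the usual JUnit 5 factory-method contract (static, returning a Stream of Arguments with one entry per backend to run against), a minimal sketch would be:

    import java.util.stream.Stream;
    import org.junit.jupiter.params.provider.Arguments;
    import org.nd4j.linalg.factory.Nd4j;

    public abstract class BaseNd4jTest {

        // Assumed shape of the shared provider; the real one may enumerate several
        // backends (for example CPU and CUDA) rather than just the active one.
        public static Stream<Arguments> configs() {
            return Stream.of(Arguments.of(Nd4j.getBackend()));
        }
    }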


@ -24,6 +24,8 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.validation.OpTestCase;
import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.linalg.api.buffer.DataType;
@ -51,10 +53,6 @@ public class ReductionBpOpValidation extends BaseOpValidation {
private DataType initialType;
public ReductionBpOpValidation(Nd4jBackend backend) {
super(backend);
}
@BeforeEach
public void before() {
Nd4j.create(1);
@ -71,14 +69,16 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@AfterEach
public void tearDown() {
public void tearDown(Nd4jBackend backend) {
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
}
@Test
public void testReduceSumBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduceSumBP(Nd4jBackend backend) {
//Full array reduction
//reduce_sum_bp op: has 2 inputs (original pre-reduce input, and gradient at output (epsilon))
@ -104,7 +104,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testReduceSumAlongDim0BP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduceSumAlongDim0BP(Nd4jBackend backend) {
//Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar
@ -130,7 +132,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testReduceSumAlongDim1BP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduceSumAlongDim1BP(Nd4jBackend backend) {
//Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar
@ -158,7 +162,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test
public void testMeanBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanBP(Nd4jBackend backend) {
//dL/dIn_i = dL/dOut * dOut/dIn_i = dL/dOut * (1/N * sum_j (in_j))
// = 1/N * dL/dOut
@ -189,7 +195,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testMeanBP_Rank1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanBP_Rank1(Nd4jBackend backend) {
INDArray dLdOut = Nd4j.scalar(0.5);
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5 / 3);
@ -202,7 +210,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testMeanAlongDim0BP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanAlongDim0BP(Nd4jBackend backend) {
//Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar
@ -230,7 +240,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testMeanAlongDim1BP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeanAlongDim1BP(Nd4jBackend backend) {
//Reduction along dimension
//Inputs/outputs as before - but note that the output is no longer a scalar
@ -258,7 +270,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test
public void testMinBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMinBP(Nd4jBackend backend) {
//Full array min reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i
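(For min, dOut/dIn_i is 1 only at the position holding the minimum and 0 everywhere else, so dL/dOut flows to that single element; the max reductions below behave the same way with respect to the maximum.)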
@ -297,7 +311,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testMinAlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMinAlongDimensionBP(Nd4jBackend backend) {
//Min reduction along dimension
//dL/dIn_i = dL/dOut * dOut/dIn_i
@ -340,7 +356,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testMaxBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxBP(Nd4jBackend backend) {
//Full array max reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i
@ -370,7 +388,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testMaxAlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxAlongDimensionBP(Nd4jBackend backend) {
//Max reduction along dimension
//dL/dIn_i = dL/dOut * dOut/dIn_i
@ -413,7 +433,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testProdBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testProdBP(Nd4jBackend backend) {
//Full array product reduction
//dL/dIn_i = dL/dOut * dOut/dIn_i
@ -442,7 +464,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testProdAlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testProdAlongDimensionBP(Nd4jBackend backend) {
//dL/dIn_i = dL/dOut * dOut/dIn_i
// = dL/dOut * d(prod(in))/dIn_i
// = dL/dOut * (prod(in) / in_i)
@ -498,7 +522,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testStdevBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdevBP(Nd4jBackend backend) {
//If out = stdev(in) then:
//dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = (in_i-mean)/(stdev * (n-1))
@ -534,7 +560,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testStdevBP_Rank1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdevBP_Rank1(Nd4jBackend backend) {
INDArray dLdOut = Nd4j.scalar(0.5);
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
double stdev = preReduceInput.stdNumber(true).doubleValue();
@ -555,7 +583,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testStdevAlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdevAlongDimensionBP(Nd4jBackend backend) {
//If out = stdev(in) then:
//dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = (in_i-mean)/(stdev * (n-1))
@ -600,7 +630,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testVarianceBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVarianceBP(Nd4jBackend backend) {
//If out = variance(in) then:
//dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = 2*(in_i-mean)/(n-1)
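(Concretely: for in = [2, 3, 4], mean = 3 and n = 3, so dOut/dIn = [2*(2-3)/2, 2*(3-3)/2, 2*(4-3)/2] = [-1, 0, 1]; with dL/dOut = 0.5 this gives dL/dIn = [-0.5, 0, 0.5]. Illustrative numbers, not taken from the test.)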
@ -636,7 +668,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testVarianceAlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVarianceAlongDimensionBP(Nd4jBackend backend) {
//If out = variance(in) then:
//dL/dIn = dL/dOut * dOut/dIn
//dOut/dIn_i = 2*(in_i-mean)/(n-1)
@ -678,7 +712,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test
public void testCumSumBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCumSumBP(Nd4jBackend backend) {
//Standard case, non-reverse, non-exclusive
//dL/dIn_i = sum_j dL/dOut_j * dOut_j/dIn_i
// = sum_j dL/dOut_j * d(in_0 + ... + in_j)/dIn_i
@ -748,7 +784,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@Test
public void testNorm2Bp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm2Bp(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * x/|x|_2
@ -775,7 +813,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testNorm2AlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm2AlongDimensionBP(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * x/|x|_2
@ -808,7 +848,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testNorm1Bp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm1Bp(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * sgn(in)
@ -835,7 +877,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testNorm1AlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm1AlongDimensionBP(Nd4jBackend backend) {
//dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * sgn(in)
@ -867,7 +911,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testNormMaxBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormMaxBp(Nd4jBackend backend) {
//out = max_i (|in_i|)
//dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise)
@ -897,7 +943,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
}
@Test
public void testNormMaxAlongDimensionBP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormMaxAlongDimensionBP(Nd4jBackend backend) {
//out = max_i (|in_i|)
//dL/dIn = dL/dOut * dOut/dIn
// = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise)
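(Concretely: for in = [-2, 1, 3] the largest magnitude is |3| at the last position, so with dL/dOut = 0.5 the gradient is [0, 0, 0.5]; if the maximum-magnitude entry were negative, e.g. in = [-4, 1, 3], it would be [-0.5, 0, 0]. Illustrative numbers, not taken from the test.)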


@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -76,16 +77,13 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
@Slf4j
@RunWith(Parameterized.class)
public class ReductionOpValidation extends BaseOpValidation {
public ReductionOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test
public void testStdev() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStdev(Nd4jBackend backend) {
List<String> errors = new ArrayList<>();
for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345, DataType.DOUBLE)) {
@ -111,7 +109,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testZeroCount() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZeroCount(Nd4jBackend backend) {
List<String> allFailed = new ArrayList<>();
for (int i = 0; i < 21; i++) {
SameDiff sd = SameDiff.create();
@ -145,7 +145,9 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test
public void testZeroFraction() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZeroFraction(Nd4jBackend backend) {
List<String> allFailed = new ArrayList<>();
for (int i = 0; i < 2; i++) {
SameDiff sd = SameDiff.create();
@ -175,7 +177,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testReductionGradientsSimple() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionGradientsSimple(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES
//Test reductions: final and only function
Nd4j.getRandom().setSeed(12345);
@ -344,7 +348,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testReductionGradients1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionGradients1(Nd4jBackend backend) {
//Test reductions: final, but *not* the only function
Nd4j.getRandom().setSeed(12345);
@ -472,7 +478,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testReductionGradients2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionGradients2(Nd4jBackend backend) {
//Test reductions: NON-final function
Nd4j.getRandom().setSeed(12345);
@ -650,7 +658,9 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test
public void testReduce3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduce3(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int d0 = 3;
@ -755,7 +765,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testMoments() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMoments(Nd4jBackend backend) {
for (int[] axes : new int[][]{{0}, {1}, {0, 1}}) {
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
@ -787,9 +799,11 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testMomentsOp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMomentsOp(Nd4jBackend backend) {
int[] axes = new int[]{0};
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray outMean = Nd4j.createUninitialized(new long[]{4});
INDArray outVar = Nd4j.createUninitialized(new long[]{4});
@ -804,7 +818,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testNormalizeMomentsOp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormalizeMomentsOp(Nd4jBackend backend) {
INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10);
INDArray ssSum = data.sum(0);
INDArray ssSqSum = data.mul(data).sum(0);
@ -824,7 +840,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testAllAny() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllAny(Nd4jBackend backend) {
INDArray allZeros = Nd4j.zeros(DataType.FLOAT, 3, 4);
INDArray allOnes = Nd4j.ones(DataType.FLOAT, 3, 4);
@ -852,7 +870,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testIndexAccum() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexAccum(Nd4jBackend backend) {
List<String> failed = new ArrayList<>();
List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/);
@ -941,7 +961,9 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test
public void testReduce3_2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReduce3_2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int d0 = 3;
@ -1039,7 +1061,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testReductionsBackwards() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionsBackwards(Nd4jBackend backend) {
// for (int i = 0; i < 7; i++) {
int i=5;
{
@ -1108,6 +1132,8 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttention(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1127,12 +1153,14 @@ public class ReductionOpValidation extends BaseOpValidation {
t.norm1("out");
String err = OpValidation.validate(new TestCase(sd)
.expectedOutput("out", finalOut)
.gradientCheck(true));
assertNull(err);
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionWithMask(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1163,6 +1191,8 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionMultiHeadInputWithMask(){
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
@ -1194,6 +1224,8 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionMultiHeadInput(){
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
@ -1221,6 +1253,8 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMultiHeadedDotProductAttention(){
final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
@ -1272,6 +1306,8 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDotProductAttentionWeirdInputs(){
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
@ -1309,6 +1345,8 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMultiHeadedDotProductAttentionWeirdInputs(){
final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
@ -1366,7 +1404,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
}
@Test
public void testSufficientStatisticsOp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSufficientStatisticsOp(Nd4jBackend backend) {
INDArray data = Nd4j.createFromArray(new double[]{
5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1.,
1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5
@ -1392,7 +1432,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testStandardDeviation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardDeviation(Nd4jBackend backend) {
for (boolean keepDims : new boolean[]{false, true}) {
SameDiff sameDiff = SameDiff.create();
@ -1419,7 +1461,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testSquaredNorm() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSquaredNorm(Nd4jBackend backend) {
for (boolean keepDims : new boolean[]{false, true}) {
SameDiff sameDiff = SameDiff.create();
@ -1442,7 +1486,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testShannonEntropy() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShannonEntropy(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695
SameDiff sameDiff = SameDiff.create();
@ -1462,7 +1508,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testEntropy() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEntropy(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1481,7 +1529,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testAMean() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAMean(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1502,7 +1552,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testMean() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMean(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1523,7 +1575,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testNorm1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm1(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1544,7 +1598,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testNorm2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNorm2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1565,7 +1621,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testNormMax() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNormMax(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
@ -1586,7 +1644,9 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
public void testSoftmaxCrossEntropyWithLogitsLoss() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSoftmaxCrossEntropyWithLogitsLoss(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create();


@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
@ -43,12 +45,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j
public class RnnOpValidation extends BaseOpValidation {
public RnnOpValidation(Nd4jBackend backend) {
super(backend);
}
@Test
public void testRnnBlockCell(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRnnBlockCell(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int mb = 2;
int nIn = 3;
@ -147,7 +148,9 @@ public class RnnOpValidation extends BaseOpValidation {
@Test
public void testRnnBlockCellManualTFCompare() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRnnBlockCellManualTFCompare(Nd4jBackend backend) {
//Test case: "rnn/lstmblockcell/static_batch1_n3-2_tsLength1_noPH_noClip_fBias1_noIS"
SameDiff sd = SameDiff.create();
@ -209,6 +212,8 @@ public class RnnOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGRUCell(){
Nd4j.getRandom().setSeed(12345);
int mb = 2;


@ -28,6 +28,8 @@ import lombok.val;
import org.apache.commons.math3.linear.LUDecomposition;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -67,9 +69,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
@Slf4j
public class ShapeOpValidation extends BaseOpValidation {
public ShapeOpValidation(Nd4jBackend backend) {
super(backend);
}
/*
To test:
@ -83,7 +82,9 @@ public class ShapeOpValidation extends BaseOpValidation {
*/
@Test
public void testConcat() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat(Nd4jBackend backend) {
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
int[] concatDim = new int[]{0, 0, 0};
List<List<int[]>> origShapes = new ArrayList<>();
@ -123,7 +124,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testReshapeGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReshapeGradient(Nd4jBackend backend) {
//https://github.com/deeplearning4j/deeplearning4j/issues/6873
int[] origShape = new int[]{3, 4, 5};
@ -159,7 +162,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testPermuteGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermuteGradient(Nd4jBackend backend) {
int[] origShape = new int[]{3, 4, 5};
List<String> failed = new ArrayList<>();
@ -197,6 +202,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRank(){
List<long[]> inShape = Arrays.asList(null, new long[]{1}, new long[]{6}, new long[]{3,4}, new long[]{3,4,5});
@ -224,7 +231,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testExpandDimsGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExpandDimsGradient(Nd4jBackend backend) {
val origShape = new long[]{3, 4};
List<String> failed = new ArrayList<>();
@ -280,7 +289,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testSqueezeGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSqueezeGradient(Nd4jBackend backend) {
val origShape = new long[]{3, 4, 5};
List<String> failed = new ArrayList<>();
@ -344,7 +355,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testSliceGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSliceGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
//Order here: original shape, begin, size
@ -434,7 +447,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testStridedSliceGradient() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceGradient(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
//Order here: original shape, begin, size
@ -497,7 +512,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testMerge() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMerge(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
@ -573,7 +590,7 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test()
public void testStack() {
public void testStack(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
@ -664,7 +681,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testUnStack() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnStack(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
@ -752,7 +771,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testTile() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTile(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
List<int[]> tileArg = Arrays.asList(
@ -824,6 +845,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTileBp(){
Nd4j.getRandom().setSeed(12345);
@ -857,6 +880,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTileBp2(){
Nd4j.getRandom().setSeed(12345);
@ -891,7 +916,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testReshape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReshape(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(-5, 6, 12)).reshape(3, 4);
SDVariable x = sameDiff.var("x", arr);
@ -907,7 +934,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testReshape2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReshape2(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int[] origShape = new int[]{3, 4, 5};
@ -930,7 +959,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testTranspose() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTranspose(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
SDVariable x = sameDiff.var("x", arr);
@ -942,6 +973,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTransposeOp(){
INDArray arr = Nd4j.linspace(1,15, 15).reshape(5,3);
@ -955,7 +988,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShape(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
val shape = new long[]{2, 3};
SDVariable x = sameDiff.var("x", shape);
@ -970,7 +1005,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testSize() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSize(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
val shape = new long[]{2, 3};
SDVariable x = sameDiff.var("x", DataType.FLOAT, shape);
@ -984,7 +1021,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testDiagShapeFn() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagShapeFn(Nd4jBackend backend) {
INDArray i = Nd4j.linspace(1, 16, 16).reshape(4,4);
OpTestCase op = new OpTestCase(new DiagPart(i, null));
@ -998,6 +1037,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute(){
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
INDArray exp = in.permute(0,1,2); //No op
@ -1012,6 +1053,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute2(){
for (int[] perm : new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}) {
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
@ -1032,6 +1075,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConstant(){
//OpValidationSuite.ignoreFailing();
@ -1059,6 +1104,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnstackEdgeCase2(){
for( int i=0; i<3; i++ ) {
@ -1073,7 +1120,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void invertPermutation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void invertPermutation(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new float[] {3, 4, 0, 2, 1}).castTo(DataType.INT);
@ -1090,6 +1139,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherNd(){
List<INDArray> indices = new ArrayList<>();
@ -1128,7 +1179,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testReverseSequence() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReverseSequence(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
float[] input_data = new float[]{
1, 2, 3,
@ -1174,6 +1227,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixDeterminant(){
OpValidationSuite.ignoreFailing(); //Gradient check failing
@ -1195,6 +1250,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeterminant22(){
OpValidationSuite.ignoreFailing(); //Gradient check failing
@ -1219,6 +1276,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixDeterminant3(){
OpValidationSuite.ignoreFailing(); //Gradient checks failing
Nd4j.getRandom().setSeed(12345);
@ -1250,6 +1309,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrixDeterminant4(){
OpValidationSuite.ignoreFailing(); //Gradient checks failing
Nd4j.getRandom().setSeed(12345);
@ -1270,6 +1331,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentOps(){
OpValidationSuite.ignoreFailing();
//https://github.com/deeplearning4j/deeplearning4j/issues/6952
@ -1362,6 +1425,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentMean(){
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3);
INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2);
@ -1382,7 +1447,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testSequenceMask() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSequenceMask(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2});
// arr is not trainable, so it's constant in model
@ -1391,10 +1458,10 @@ public class ShapeOpValidation extends BaseOpValidation {
// Test with static max len
int maxlen = 5;
INDArray expected = Nd4j.create(new float[] {
1.f, 0.f, 0.f, 0.f, 0.f,
1.f, 1.f, 1.f, 0.f, 0.f,
1.f, 1.f, 0.f, 0.f, 0.f
}).reshape(3,5);
INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.FLOAT));
SDVariable result1 = sameDiff.sequenceMask(lengths, maxlen, DataType.FLOAT);
assertArrayEquals(expected.shape(), result1.eval().shape());
@ -1416,6 +1483,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMeshGrid(){
List<String> failed = new ArrayList<>();
@ -1472,6 +1541,8 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGather(){
List<INDArray> inArrs = new ArrayList<>();
List<Integer> axis = new ArrayList<>();
@ -1541,7 +1612,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testGatherSimple() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherSimple(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr = Nd4j.create(new float[]{1, 2, 3, 4}, new long[]{2, 2});
SDVariable x = sameDiff.var("x", arr);
@ -1551,7 +1624,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testGatherNdSingle() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherNdSingle(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(DataType.DOUBLE, 1, 24, 24)).reshape(2, 3, 4);
INDArray arr2 = Nd4j.create(new float[]{1, 2, 3, 0, 1, 3, 1, 0, 2}, new long[]{3, 3}).castTo(DataType.INT);
@ -1563,14 +1638,16 @@ public class ShapeOpValidation extends BaseOpValidation {
for (int i=0; i<3; i++){
INDArray idx = arr2.get(point(i), NDArrayIndex.all());
expected.putScalar(i, arr1.get(point(idx.getInt(0)),
point(idx.getInt(1)),
point(idx.getInt(2))).getDouble(0));
}
assertEquals(expected, result.eval());
}
@Test
public void testStack2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStack2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2);
@ -1581,7 +1658,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testParallelStack() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testParallelStack(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2);
@ -1593,7 +1672,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testUnStack2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnStack2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Nd4j.zeros(3, 2);
INDArray arr2 = Nd4j.ones(3, 2);
@ -1606,7 +1687,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testPermuteSimple() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermuteSimple(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3));
SDVariable x = sameDiff.var("x", arr);
@ -1617,7 +1700,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testConcat2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(4, 8, 4)).reshape(1,4);
@ -1628,7 +1713,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testTile2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTile2(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1,4));
SDVariable x = sameDiff.var("x", arr);
@ -1641,7 +1728,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testSlice2d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlice2d(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
SameDiff sd = SameDiff.create();
@ -1657,7 +1746,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testSlice3d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlice3d(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create();
@ -1672,7 +1763,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testStridedSlice2dBasic() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSlice2dBasic(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
SameDiff sd = SameDiff.create();
@ -1690,7 +1783,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testStridedSliceBeginEndMask() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceBeginEndMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
SameDiff sd = SameDiff.create();
@ -1705,7 +1800,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testStridedSliceEllipsisMask() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceEllipsisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr);
@ -1722,7 +1819,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testStridedSliceNewAxisMask() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceNewAxisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr);
@ -1735,7 +1834,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testStridedSliceNewAxisMask2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceNewAxisMask2(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", inArr);
@ -1746,7 +1847,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testStridedSliceShrinkAxisMask() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceShrinkAxisMask(Nd4jBackend backend) {
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
SameDiff sd = SameDiff.create();
@ -1763,7 +1866,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testSizeAt_1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSizeAt_1(Nd4jBackend backend) {
val array = Nd4j.create(10, 20, 30);
val exp = Nd4j.scalar(DataType.LONG, 20);
@ -1777,6 +1882,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEye(){
int[] rows = new int[]{3,3,3,3};
int[] cols = new int[]{3,2,2,2};
@ -1815,6 +1922,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSplit1(){
INDArray in = Nd4j.linspace(1,10,10).reshape(10);
INDArray axis = Nd4j.scalar(-1);
@ -1833,6 +1942,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSplit2(){
INDArray in = Nd4j.linspace(1,24,24).reshape(3,8);
INDArray axis = Nd4j.scalar(-1);
@ -1851,6 +1962,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDistancesExec(){
//https://github.com/deeplearning4j/deeplearning4j/issues/7001
for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) {
@ -1906,6 +2019,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReductionShape(){
INDArray shape = Nd4j.createFromArray(4,2);
@ -1924,6 +2039,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void gatherTest(){
INDArray in = Nd4j.createFromArray(new double[][]{
{1,2,3,4,5},
@ -1943,6 +2060,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSliceShape(){
INDArray arr = Nd4j.arange(0, 25).reshape(1,5,5).castTo(DataType.INT);
@ -1964,6 +2083,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testWhereAllFalse(){
INDArray in = Nd4j.create(DataType.BOOL, 1917);
DynamicCustomOp op = DynamicCustomOp.builder("Where")
@ -1978,6 +2099,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherScalar(){
INDArray in = Nd4j.linspace(100, 200, 100, DataType.FLOAT).reshape(100);
INDArray indices = Nd4j.scalar(0);
@ -2002,6 +2125,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCastEmpty(){
INDArray emptyLong = Nd4j.empty(DataType.LONG);
int dtype = 9; //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h
@ -2018,6 +2143,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGatherEmpty(){
/*
tf.reset_default_graph()
@ -2050,6 +2177,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSplitEmpty(){
/*
tf.reset_default_graph()
@ -2087,6 +2216,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcatEmpty(){
/*
TF behaviour with concatenation of empty arrays:
@ -2136,6 +2267,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcatEmpty2(){
INDArray empty10a = Nd4j.create(DataType.INT, 1, 0);
INDArray empty10b = Nd4j.create(DataType.INT, 1, 0);
@ -2168,6 +2301,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyGather(){
/*
tf.reset_default_graph()
@ -2200,6 +2335,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastDynamicShape1(){
//Test case: [2,1] and [4]: expect [2,4]
@ -2221,6 +2358,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastDynamicShape2(){
//Test case: [2,1,4] and [2,2,4]: expect [2,2,4]
@ -2243,6 +2382,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceShrinkAxis(){
INDArray in = Nd4j.create(DataType.DOUBLE, 3,2,2);
INDArray begin = Nd4j.createFromArray(2);
@ -2268,6 +2409,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceEmpty(){
INDArray in = Nd4j.createFromArray(10); //Integer, Length 1, rank 1, value 10 - Not used due to begin mask!
@ -2290,6 +2433,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStridedSliceEdgeCase(){
INDArray in = Nd4j.scalar(10).reshape(1); //Int [1]
INDArray begin = Nd4j.ones(DataType.INT, 1);
@ -2315,6 +2460,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptySlice1(){
INDArray in = Nd4j.createFromArray(38);
INDArray begin = Nd4j.createFromArray(1);
@ -2334,6 +2481,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptySlice2(){
INDArray in = Nd4j.createFromArray(38);
INDArray begin = Nd4j.createFromArray(0);
@ -2353,6 +2502,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFill(){
INDArray shape = Nd4j.createFromArray(0,4);
@ -2372,6 +2523,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFill2(){
INDArray shape = Nd4j.createFromArray(0,4);
@ -2389,6 +2542,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermuteShapeDynamicAxis(){
DynamicCustomOp op = DynamicCustomOp.builder("permute")
@ -2418,6 +2573,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGather2(){
SameDiff sd = SameDiff.create();
SDVariable input = sd.var("in", Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3));
@ -2437,6 +2594,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute3(){
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
INDArray permute = Nd4j.createFromArray(1,0);
@ -2455,6 +2614,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute4(){
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
INDArray permute = Nd4j.createFromArray(1,0);
@ -2485,6 +2646,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvertPermutation(){
DynamicCustomOp op = DynamicCustomOp.builder("invert_permutation")
.addInputs(Nd4j.createFromArray(1, 0))
@ -2492,7 +2655,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testBroadcastInt1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastInt1(Nd4jBackend backend) {
INDArray out = Nd4j.create(DataType.INT, 1);
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
@ -2505,6 +2670,8 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastInt2(){
INDArray out = Nd4j.create(DataType.INT, 2);
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
@ -2544,7 +2711,9 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testMergeMaxIndex() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeMaxIndex(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2561,7 +2730,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testTriOp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTriOp(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable out = new Tri(sd, DataType.INT32, 3, 5, 2).outputVariable();
@ -2573,7 +2744,9 @@ public class ShapeOpValidation extends BaseOpValidation {
}
@Test
public void testTriuOp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTriuOp(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable input = sd.var(Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {7,8,9},{10,11,12}}));
@ -2581,8 +2754,8 @@ public class ShapeOpValidation extends BaseOpValidation {
out.markAsLoss();
INDArray expected = Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {0,8,9},{0,0,12}});
String err = OpValidation.validate(new TestCase(sd)
.expectedOutput("triu", expected)
.gradientCheck(true));
.expectedOutput("triu", expected)
.gradientCheck(true));
assertNull(err);
}
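Throughout these migrated classes, @MethodSource points at a fully qualified external provider, "org.nd4j.linalg.BaseNd4jTest#configs", rather than a provider declared in the test class itself. A minimal sketch of what such a provider could look like is below; the class and field names are illustrative assumptions, not the actual nd4j implementation, and the real configs() may discover backends differently (for example via a service loader).

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

import org.junit.jupiter.params.provider.Arguments;
import org.nd4j.linalg.factory.Nd4jBackend;

// Illustrative stand-in for the provider referenced as
// "org.nd4j.linalg.BaseNd4jTest#configs" in the annotations above.
public class BackendArgumentsProvider {

    // Assumption: the backends to run against are collected into a list elsewhere.
    private static final List<Nd4jBackend> BACKENDS = new ArrayList<>();

    // Each backend becomes one Arguments instance, i.e. one invocation of every
    // @ParameterizedTest that references this method via @MethodSource.
    public static Stream<Arguments> configs() {
        return BACKENDS.stream().map(Arguments::of);
    }
}

Each parameterized test then receives one Nd4jBackend argument per invocation, replacing the constructor injection that the JUnit 4 Parameterized runner provided before this migration.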

View File

@ -26,6 +26,8 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
@ -94,9 +96,6 @@ public class TransformOpValidation extends BaseOpValidation {
private DataType initialType;
public TransformOpValidation(Nd4jBackend backend) {
super(backend);
}
@BeforeEach
public void before() {
@ -120,7 +119,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testScalarOps() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarOps(Nd4jBackend backend) {
int d0 = 2;
int d1 = 3;
int d2 = 4;
@ -217,7 +218,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testScalarMulCF() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarMulCF(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
INDArray outC = Nd4j.createUninitialized(3, 4);
@ -231,7 +234,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testScalarMulCF2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarMulCF2(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
@ -242,7 +247,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testCross() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCross(Nd4jBackend backend) {
INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3});
INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3});
@ -270,7 +277,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testSpaceToDepth() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpaceToDepth(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(1337);
int miniBatch = 128;
@ -298,7 +307,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testDepthToSpace() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepthToSpace(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(1337);
int miniBatch = 128;
@ -325,7 +336,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testBatchToSpace() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBatchToSpace(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
Nd4j.getRandom().setSeed(1337);
@ -362,7 +375,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testSpaceToBatch() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpaceToBatch(Nd4jBackend backend) {
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
Nd4j.getRandom().setSeed(7331);
@ -400,7 +415,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testDynamicPartition() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicPartition(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new double[]{4, 3, 5, 7, 8, 0});
@ -440,7 +457,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testDynamicPartition2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicPartition2(Nd4jBackend backend) {
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
@ -458,7 +477,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testDynamicStitch() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDynamicStitch(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new double[]{5, 1, 3}, new long[]{3});
@ -495,7 +516,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testDiag() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiag(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
INDArray ia = Nd4j.create(new double[]{1, 2}, new int[]{2});
@ -521,7 +544,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testDiagPart() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDiagPart(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
INDArray input = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
@ -540,7 +565,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testEye() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEye(Nd4jBackend backend) {
int[] rows = new int[]{3, 3, 3, 3};
int[] cols = new int[]{3, 2, 2, 2};
int[][] batch = new int[][]{{}, {}, {4}, {3, 3}};
@ -574,7 +601,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testEyeShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEyeShape(Nd4jBackend backend) {
DynamicCustomOp dco = DynamicCustomOp.builder("eye")
.addIntegerArguments(3, 3)
//.addIntegerArguments(-99,3,3) //Also fails
@ -586,7 +615,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testTransforms() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTransforms(Nd4jBackend backend) {
//Test transforms (non-pairwise)
Nd4j.getRandom().setSeed(12345);
@ -1074,7 +1105,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testPairwiseTransforms() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPairwiseTransforms(Nd4jBackend backend) {
/*
add, sub, mul, div, rsub, rdiv
eq, neq, gt, lt, gte, lte, or, and, xor
@ -1258,7 +1291,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testIsX() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIsX(Nd4jBackend backend) {
List<String> failed = new ArrayList<>();
for (int i = 0; i < 4; i++) {
@ -1313,7 +1348,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testReplaceWhereScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReplaceWhereScalar(Nd4jBackend backend) {
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
log.info("Testing condition: " + c.getClass().getSimpleName());
@ -1335,7 +1372,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testReplaceWhereArray() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReplaceWhereArray(Nd4jBackend backend) {
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
INDArray inArr = Nd4j.rand(3, 4);
@ -1358,7 +1397,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testLogGrad() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogGrad(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
SDVariable input = sameDiff.var("x", Nd4j.linspace(1, 4, 4, DataType.DOUBLE));
SDVariable log = sameDiff.math().log(input);
@ -1369,7 +1410,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testSigmoidBackwards() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSigmoidBackwards(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
Map<String, INDArray> inputs = new HashMap<>();
@ -1386,8 +1429,10 @@ public class TransformOpValidation extends BaseOpValidation {
}
/* @Test
public void testDepth() {
/* @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDepth(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
SDVariable x = sameDiff.one("one",new long[]{2,2});
assertEquals(0,x.depth());
@ -1396,7 +1441,9 @@ public class TransformOpValidation extends BaseOpValidation {
}*/
@Test
public void testRank0EdgeCase() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRank0EdgeCase(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4})));
double d0 = v1.eval().getDouble(0);
@ -1409,7 +1456,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testAtan2BroadcastShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAtan2BroadcastShape(Nd4jBackend backend) {
INDArray arr1 = Nd4j.create(new long[]{3, 1, 4});
INDArray arr2 = Nd4j.create(new long[]{1, 2, 4});
@ -1424,7 +1473,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testBooleanAnd() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBooleanAnd(Nd4jBackend backend) {
Nd4j.setDataType(DataType.FLOAT);
INDArray arr1 = Nd4j.create(new long[]{3, 4});
INDArray arr2 = Nd4j.create(new long[]{3, 4});
@ -1438,7 +1489,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testScatterOpsScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScatterOpsScalar(Nd4jBackend backend) {
for (String s : new String[]{"add", "sub", "mul", "div"}) {
INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
INDArray indices = Nd4j.scalar(5);
@ -1483,7 +1536,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Disabled("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540")
@Test
public void testPad() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPad(Nd4jBackend backend) {
INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG);
INDArray value = Nd4j.scalar(10.0);
@ -1510,7 +1565,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testMirrorPad() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMirrorPad(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
@ -1543,7 +1600,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMirrorPad2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMirrorPad2(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
@ -1569,7 +1628,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMirrorPadSymmetric() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMirrorPadSymmetric(Nd4jBackend backend) {
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4);
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT);
@ -1596,7 +1657,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testUnique() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUnique(Nd4jBackend backend) {
INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4});
INDArray expUnique = Nd4j.create(new double[]{3, 4, 1, 0, 2});
@ -1618,7 +1681,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testTopK() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopK(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //Can't assume sorted here
INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8});
@ -1647,7 +1712,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testTopK1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopK1(Nd4jBackend backend) {
INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
INDArray k = Nd4j.scalar(1);
INDArray outValue = Nd4j.create(DataType.DOUBLE, 1);
@ -1668,7 +1735,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testInTopK() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInTopK(Nd4jBackend backend) {
for (int k = 4; k >= 1; k--) {
log.info("Testing: k=" + k);
INDArray in = Nd4j.linspace(1, 20, 20, DataType.DOUBLE).reshape(4, 5);
@ -1709,7 +1778,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testZeta() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testZeta(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182
INDArray x = Nd4j.rand(3, 4).addi(1.0);
INDArray q = Nd4j.rand(3, 4);
@ -1726,7 +1797,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMaxEmptyScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaxEmptyScalar(Nd4jBackend backend) {
INDArray empty = Nd4j.empty(DataType.FLOAT);
INDArray scalar = Nd4j.scalar(1.0f);
@ -1743,7 +1816,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testBroadcastEmpty() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastEmpty(Nd4jBackend backend) {
// Nd4j.getExecutioner().enableVerboseMode(true);
// Nd4j.getExecutioner().enableDebugMode(true);
//Check broadcast behaviour with empty arrays. The idea is to match TF import behaviour, for import
@ -1833,7 +1908,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testStandardize() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardize(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4});
final int[] axis = new int[]{1};
@ -1854,7 +1931,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testStandardizeOP() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardizeOP(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4});
final int[] axis = new int[]{1};
@ -1869,7 +1948,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testStandardizeNoDeviation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStandardizeNoDeviation(Nd4jBackend backend) {
final INDArray random = Nd4j.rand(new int[]{10, 4});
for (int i = 0; i < 4; i++) {
random.putScalar(1, i, 7);
@ -1895,7 +1976,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMatMulTensor() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatMulTensor(Nd4jBackend backend) {
final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5});
final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6});
@ -1915,7 +1998,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMatMulTensorTranspose() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatMulTensorTranspose(Nd4jBackend backend) {
for (boolean transposeA : new boolean[]{false, true}) {
for (boolean transposeB : new boolean[]{false, true}) {
for (boolean transposeResult : new boolean[]{false, true}) {
@ -2008,7 +2093,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testSoftmaxCF() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSoftmaxCF(Nd4jBackend backend) {
INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5);
INDArray arrF = arrC.dup('f');
@ -2029,7 +2116,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testLogSumExp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogSumExp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4);
SameDiff sd = SameDiff.create();
@ -2044,7 +2133,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testLogSumExp2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogSumExp2(Nd4jBackend backend) {
for (int dim = 0; dim <= 2; dim++) {
Nd4j.getRandom().setSeed(12345);
@ -2065,7 +2156,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testCRELU() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCRELU(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2);
@ -2084,7 +2177,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testClipByAvgNorm() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testClipByAvgNorm(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2, 2);
@ -2105,7 +2200,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testEmbeddingLookup() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmbeddingLookup(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2118,49 +2215,53 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testImageResize() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testImageResize(Nd4jBackend backend) {
//TODO: Methods that currently fail: ResizeLanczos5, ResizeMitchelcubic, ResizeArea
for (ImageResizeMethod method : ImageResizeMethod.values()) {
if (method==ImageResizeMethod.ResizeLanczos5 || method==ImageResizeMethod.ResizeArea || method == ImageResizeMethod.ResizeMitchelcubic)
{continue;}
if (method==ImageResizeMethod.ResizeLanczos5 || method==ImageResizeMethod.ResizeArea || method == ImageResizeMethod.ResizeMitchelcubic)
{continue;}
log.info("Trying {}", method);
log.info("Trying {}", method);
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
boolean preserveAspectRatio = true;
boolean antialias = true;
SDVariable inputImage = sd.var(Nd4j.rand(DataType.FLOAT, 1, 5, 5, 3));
// NHWC format
long[] expectedShape = new long[]{1, 3, 3, 3};
SDVariable requestedSize = sd.constant(Nd4j.createFromArray( new long[]{3, 3}));
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
boolean preserveAspectRatio = true;
boolean antialias = true;
SDVariable inputImage = sd.var(Nd4j.rand(DataType.FLOAT, 1, 5, 5, 3));
// NHWC format
long[] expectedShape = new long[]{1, 3, 3, 3};
SDVariable requestedSize = sd.constant(Nd4j.createFromArray( new long[]{3, 3}));
Function<INDArray, String> checkFunction = in -> {
boolean shapeOk = Arrays.equals(expectedShape, in.shape());
if (shapeOk) return null;
return "Failed: shape differs - expected " + Arrays.toString(expectedShape) + " vs " + Arrays.toString(in.shape()) + " on method " + method;
};
Function<INDArray, String> checkFunction = in -> {
boolean shapeOk = Arrays.equals(expectedShape, in.shape());
if (shapeOk) return null;
return "Failed: shape differs - expected " + Arrays.toString(expectedShape) + " vs " + Arrays.toString(in.shape()) + " on method " + method;
};
SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable().std(true);
SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable().std(true);
String err = OpValidation.validate(new TestCase(sd)
.gradientCheck(false)
.expected("image_resize", checkFunction));
String err = OpValidation.validate(new TestCase(sd)
.gradientCheck(false)
.expected("image_resize", checkFunction));
assertNull(err);
}
}
}
@Test
public void testMaximumBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMaximumBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2177,7 +2278,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMergeAddBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeAddBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2194,7 +2297,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testMergeMaxBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeMaxBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2212,7 +2317,9 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testMergeAvgBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMergeAvgBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2229,7 +2336,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testReverseBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReverseBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
@ -2243,7 +2352,9 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testUpsampling3dBp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUpsampling3dBp(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
for (boolean dataformat : new boolean[]{true, false}) {

View File

@ -24,8 +24,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
@ -36,11 +37,8 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.factory.Nd4jBackend;
public class ConvConfigTests extends BaseNd4jTest {
public class ConvConfigTests extends BaseNd4jTestWithBackends {
public ConvConfigTests(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -48,7 +46,9 @@ public class ConvConfigTests extends BaseNd4jTest {
}
@Test
public void testDeConv2D(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeConv2D(Nd4jBackend backend){
DeConv2DConfig.builder().kH(2).kW(4).build();
try{
@ -108,8 +108,10 @@ public class ConvConfigTests extends BaseNd4jTest {
}
}
@Test
public void testConv2D(){
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv2D(Nd4jBackend backend){
Conv2DConfig.builder().kH(2).kW(4).build();
try{
@ -169,8 +171,10 @@ public class ConvConfigTests extends BaseNd4jTest {
}
}
@Test
public void testPooling2D(){
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPooling2D(Nd4jBackend backend){
Pooling2DConfig.builder().kH(2).kW(4).build();
try{
@ -230,8 +234,10 @@ public class ConvConfigTests extends BaseNd4jTest {
}
}
@Test
public void testDeConv3D(){
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDeConv3D(Nd4jBackend backend){
DeConv3DConfig.builder().kH(2).kW(4).kD(3).build();
try{
@ -319,8 +325,10 @@ public class ConvConfigTests extends BaseNd4jTest {
}
}
@Test
public void testConv3D(){
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv3D(Nd4jBackend backend){
Conv3DConfig.builder().kH(2).kW(4).kD(3).build();
try{
@ -410,8 +418,10 @@ public class ConvConfigTests extends BaseNd4jTest {
@Test
public void testPooling3D(){
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPooling3D(Nd4jBackend backend){
Pooling3DConfig.builder().kH(2).kW(4).kD(3).build();
try{
@ -499,7 +509,9 @@ public class ConvConfigTests extends BaseNd4jTest {
}
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConv1D(){
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();

View File

@ -23,8 +23,10 @@ package org.nd4j.autodiff.samediff;
import lombok.val;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -40,11 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Disabled("AB 2019/05/21 - JVM Crash on ppc64 - Issue #7657")
public class FailingSameDiffTests extends BaseNd4jTest {
public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
public FailingSameDiffTests(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -52,7 +51,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
}
@Test
public void testEye(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEye(Nd4jBackend backend){
//OpValidationSuite.ignoreFailing();
INDArray arr = Nd4j.create(new double[]{1, 0, 0, 0, 1, 0}, new int[]{2, 3});
List<INDArray> stack = new ArrayList<>();
@ -68,7 +69,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
}
@Test
public void testEyeShape(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEyeShape(Nd4jBackend backend){
val dco = DynamicCustomOp.builder("eye")
.addIntegerArguments(3,3)
//.addIntegerArguments(-99,3,3) //Also fails
@ -80,7 +83,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
}
@Test
public void testExecutionDifferentShapesTransform(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExecutionDifferentShapesTransform(Nd4jBackend backend){
OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4));
@ -101,7 +106,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
}
@Test
public void testDropout() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDropout(Nd4jBackend backend) {
OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create();
double p = 0.5;
@ -114,7 +121,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
}
@Test
public void testExecutionDifferentShapesDynamicCustom(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExecutionDifferentShapesDynamicCustom(Nd4jBackend backend){
OpValidationSuite.ignoreFailing();
SameDiff sd = SameDiff.create();

View File

@ -26,13 +26,15 @@ import org.apache.commons.io.IOUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.graph.FlatConfiguration;
import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatVariable;
import org.nd4j.graph.IntPair;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
@ -70,11 +72,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j
public class FlatBufferSerdeTest extends BaseNd4jTest {
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
public FlatBufferSerdeTest(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -84,7 +83,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
@Test
public void testBasic(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBasic(Nd4jBackend backend, @TempDir Path testDir) throws Exception {
SameDiff sd = SameDiff.create();
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() );
@ -139,7 +140,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
}
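Several tests in this and the following files combine a @TempDir-injected path with the backend argument supplied by the configs provider. JUnit Jupiter requires indexed arguments from the arguments source to be declared before parameters filled by other ParameterResolver extensions such as @TempDir, so the backend parameter has to come first in these signatures. A small self-contained sketch of that ordering, using a hypothetical local provider in place of org.nd4j.linalg.BaseNd4jTest#configs:

import java.nio.file.Path;
import java.util.stream.Stream;

import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.factory.Nd4jBackend;

class TempDirOrderingSketch {

    // Hypothetical provider: a single placeholder entry so the sketch runs;
    // the real provider supplies actual Nd4jBackend instances.
    static Stream<Arguments> configs() {
        return Stream.of(Arguments.of((Nd4jBackend) null));
    }

    // Provider-supplied (indexed) argument first, resolver-injected @TempDir last.
    @ParameterizedTest
    @MethodSource("configs")
    void roundTrip(Nd4jBackend backend, @TempDir Path tempDir) {
        // a real test would write a file under tempDir and read it back on this backend
    }
}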
@Test
public void testSimple(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(Nd4jBackend backend, @TempDir Path testDir) throws Exception {
for( int i = 0; i < 10; i++ ) {
for(boolean execFirst : new boolean[]{false, true}) {
log.info("Starting test: i={}, execFirst={}", i, execFirst);
@ -268,7 +271,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
@Test
public void testTrainingSerde(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingSerde(Nd4jBackend backend, @TempDir Path testDir) throws Exception {
//Ensure 2 things:
//1. Training config is serialized/deserialized correctly
@ -352,7 +357,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
@Test
public void pooling3DSerialization(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void pooling3DSerialization(Nd4jBackend backend){
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28);
@ -372,7 +379,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
}
@Test
public void pooling3DSerialization2(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void pooling3DSerialization2(Nd4jBackend backend){
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28);

View File

@ -22,12 +22,14 @@ package org.nd4j.autodiff.samediff;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
import org.nd4j.autodiff.samediff.transform.OpPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraph;
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
@ -42,11 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j
public class GraphTransformUtilTests extends BaseNd4jTest {
public class GraphTransformUtilTests extends BaseNd4jTestWithBackends {
public GraphTransformUtilTests(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -54,7 +53,9 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
}
@Test
public void testBasic(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBasic(Nd4jBackend backend){
SameDiff sd = SameDiff.create();
SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 32);
@ -93,7 +94,9 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
}
@Test
public void testSubgraphReplace1(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSubgraphReplace1(Nd4jBackend backend){
SameDiff sd = SameDiff.create();
SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 4);

View File

@ -21,8 +21,10 @@
package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -32,11 +34,8 @@ import java.lang.reflect.Field;
import static org.junit.jupiter.api.Assertions.*;
public class MemoryMgrTest extends BaseNd4jTest {
public class MemoryMgrTest extends BaseNd4jTestWithBackends {
public MemoryMgrTest(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -44,7 +43,9 @@ public class MemoryMgrTest extends BaseNd4jTest {
}
@Test
public void testArrayReuseTooLarge() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testArrayReuseTooLarge(Nd4jBackend backend) throws Exception {
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
Field f = ArrayCacheMemoryMgr.class.getDeclaredField("maxCacheBytes");
@ -97,7 +98,7 @@ public class MemoryMgrTest extends BaseNd4jTest {
assertEquals(10, mmgr.getLruCacheValues().size());
//now, allocate some values:
for( int i=1; i<=10; i++ ) {
for( int i = 1; i <= 10; i++) {
INDArray a1 = mmgr.allocate(true, DataType.FLOAT, 25);
assertEquals(1000 - i * 100, mmgr.getCurrentCacheSize());
assertEquals(1000 - i * 100, as.getBytesSum());
@ -116,10 +117,12 @@ public class MemoryMgrTest extends BaseNd4jTest {
}
@Test
public void testManyArrays(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testManyArrays(Nd4jBackend backend){
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
for( int i=0; i<1000; i++ ){
for( int i = 0; i < 1000; i++) {
mmgr.release(Nd4j.scalar(0));
}
@ -127,7 +130,7 @@ public class MemoryMgrTest extends BaseNd4jTest {
assertEquals(1000, mmgr.getLruCache().size());
assertEquals(1000, mmgr.getLruCacheValues().size());
for( int i=0; i<1000; i++ ){
for( int i = 0; i < 1000; i++ ){
mmgr.release(Nd4j.scalar(0));
}

View File

@ -21,9 +21,11 @@
package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -35,19 +37,18 @@ import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class NameScopeTests extends BaseNd4jTest {
public class NameScopeTests extends BaseNd4jTestWithBackends {
public NameScopeTests(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
public char ordering() {
return 'c';
}
@Test
public void testVariableNameScopesBasic(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVariableNameScopesBasic(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable v = sd.var("x");
@ -73,7 +74,9 @@ public class NameScopeTests extends BaseNd4jTest {
}
@Test
public void testOpFieldsAndNames(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOpFieldsAndNames(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable x = sd.var("x", DataType.FLOAT, 1);
@ -151,7 +154,9 @@ public class NameScopeTests extends BaseNd4jTest {
}
@Test
public void testNoNesting(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNoNesting(Nd4jBackend backend) {
SameDiff SD = SameDiff.create();
SDVariable a = SD.constant(4);
@ -168,7 +173,9 @@ public class NameScopeTests extends BaseNd4jTest {
}
@Test
public void testNoTesting2(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNoTesting2(Nd4jBackend backend) {
SameDiff SD = SameDiff.create();
SDVariable a = SD.constant(4);

View File

@ -21,21 +21,16 @@
package org.nd4j.autodiff.samediff;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.imports.tfgraphs.TFGraphTestZooModels;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.io.File;
import java.nio.file.Path;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
@ -55,7 +50,9 @@ public class SameDiffMultiThreadTests extends BaseND4JTest {
}
@Test
public void testSimple() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(Nd4jBackend backend) throws Exception {
int nThreads = 4;
int nRuns = 1000;
@ -103,48 +100,6 @@ public class SameDiffMultiThreadTests extends BaseND4JTest {
}
}
@Test
@Disabled //2020/03/24 AB - https://github.com/eclipse/deeplearning4j/issues/8802
public void testMobilenet(@TempDir Path testDir) throws Exception {
TFGraphTestZooModels.currentTestDir = testDir.toFile();
File f = Resources.asFile("tf_graphs/zoo_models/mobilenet_v2_1.0_224/tf_model.txt");
SameDiff sd = TFGraphTestZooModels.LOADER.apply(f, "mobilenet_v2_1.0_224");
// System.out.println(sd.summary());
int nThreads = 4;
int nRuns = 30;
INDArray[] inputArrs = new INDArray[nThreads];
INDArray[] expOut = new INDArray[nThreads];
for( int i=0; i<nThreads; i++ ){
if(i == 0 || i > 2)
inputArrs[i] = Nd4j.rand(DataType.FLOAT, 1, 224, 224, 3);
else if(i == 1)
inputArrs[i] = Nd4j.zeros(DataType.FLOAT, 1, 224, 224, 3);
else if(i == 2)
inputArrs[i] = Nd4j.ones(DataType.FLOAT, 1, 224, 224, 3);
expOut[i] = sd.outputSingle(Collections.singletonMap("input", inputArrs[i]), "MobilenetV2/Predictions/Reshape_1");
Nd4j.getExecutioner().commit();
}
AtomicBoolean[] failuresByThread = new AtomicBoolean[nThreads];
AtomicInteger[] counters = new AtomicInteger[nThreads];
Semaphore s = new Semaphore(nThreads);
CountDownLatch latch = new CountDownLatch(nThreads);
doTest(sd, nThreads, nRuns, inputArrs, expOut, "input", "MobilenetV2/Predictions/Reshape_1", failuresByThread, counters, s, latch);
s.release(nThreads);
latch.await();
for(int i = 0; i < nThreads; i++) {
assertFalse( failuresByThread[i].get(),"Thread " + i + " failed");
}
for(int i = 0; i < nThreads; i++) {
assertEquals( nRuns, counters[i].get(),"Thread " + i + " number of runs");
}
}
public static void doTest(SameDiff sd, int nThreads, int nRuns, INDArray[] inputArrs, INDArray[] expOut,

View File

@ -21,7 +21,9 @@
package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@ -31,14 +33,13 @@ import org.nd4j.linalg.learning.config.Sgd;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class SameDiffOutputTest extends BaseNd4jTest {
public class SameDiffOutputTest extends BaseNd4jTestWithBackends {
public SameDiffOutputTest(Nd4jBackend backend) {
super(backend);
}
@Test
public void outputTest(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void outputTest(Nd4jBackend backend){
DataSet data = new DataSet(Nd4j.zeros(10, 10), Nd4j.zeros(10, 10));
SameDiff sd = SameDiff.create();

View File

@ -21,7 +21,9 @@
package org.nd4j.autodiff.samediff;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@ -35,19 +37,18 @@ import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.assertNull;
import static org.junit.jupiter.api.Assertions.*;
public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
public SameDiffSpecifiedLossVarsTests(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
public char ordering() {
return 'c';
}
@Test
public void testSpecifiedLoss1(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpecifiedLoss1(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable ph1 = sd.var("ph", DataType.FLOAT, 3, 4);
ph1.setArray(Nd4j.create(DataType.FLOAT, 3, 4));
@ -68,7 +69,9 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
}
@Test
public void testSpecifiedLoss2(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpecifiedLoss2(Nd4jBackend backend) {
for( int i=0; i<2; i++ ) {
SameDiff sd = SameDiff.create();
SDVariable ph = sd.placeHolder("ph", DataType.FLOAT, 3, 4);
@ -121,7 +124,9 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
@Test
public void testTrainingDifferentLosses(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingDifferentLosses(Nd4jBackend backend) {
//Net with 2 losses: train on the first one, then change losses
//Also check that if modifying via add/setLossVariables the training config changes

View File

@ -30,11 +30,13 @@ import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@ -55,14 +57,13 @@ import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.weightinit.impl.XavierInitScheme;
@Slf4j
public class SameDiffTrainingTest extends BaseNd4jTest {
public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
public SameDiffTrainingTest(Nd4jBackend backend) {
super(backend);
}
@Test
public void irisTrainingSanityCheck() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisTrainingSanityCheck(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize();
@ -134,7 +135,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
@Test
public void irisTrainingEvalTest() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisTrainingEvalTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize();
@ -184,7 +187,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
@Test
public void irisTrainingValidationTest() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisTrainingValidationTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize();
@ -239,6 +244,8 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingMixedDtypes(){
for (String u : new String[]{"adam", "nesterov", "adamax", "amsgrad"}) {
@ -301,7 +308,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
}
@Test
public void simpleClassification() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void simpleClassification(Nd4jBackend backend) {
double learning_rate = 0.001;
int seed = 7;
org.nd4j.linalg.api.rng.Random rng = Nd4j.getRandom();
@ -348,6 +357,8 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTrainingEvalVarNotReqForLoss(){
//If a variable is not required for the loss - normally it won't be calculated
//But we want to make sure it IS calculated here - so we can perform evaluation on it

View File

@ -25,11 +25,13 @@ import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.IrisDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -48,11 +50,8 @@ import java.util.concurrent.TimeUnit;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class CheckpointListenerTest extends BaseNd4jTest {
public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
public CheckpointListenerTest(Nd4jBackend backend){
super(backend);
}
@Override
public char ordering(){
@ -96,7 +95,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
@Test
public void testCheckpointEveryEpoch(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointEveryEpoch(Nd4jBackend backend, @TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
SameDiff sd = getModel();
@ -130,7 +131,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
}
@Test
public void testCheckpointEvery5Iter(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointEvery5Iter(Nd4jBackend backend, @TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
SameDiff sd = getModel();
@ -169,7 +172,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
@Test
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointListenerEveryTimeUnit(Nd4jBackend backend, @TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
SameDiff sd = getModel();
@ -199,7 +204,7 @@ public class CheckpointListenerTest extends BaseNd4jTest {
for(File f : files){
String s = f.getAbsolutePath();
// System.out.println(s);
for( int i=0; i<names.size(); i++ ){
for( int i = 0; i < names.size(); i++ ){
if(s.contains(names.get(i))){
found[i] = true;
break;
@ -213,7 +218,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
}
@Test
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCheckpointListenerKeepLast3AndEvery3(Nd4jBackend backend, @TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
SameDiff sd = getModel();

View File

@ -21,12 +21,13 @@
package org.nd4j.autodiff.samediff.listeners;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@ -34,14 +35,13 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.Adam;
public class ExecDebuggingListenerTest extends BaseNd4jTest {
public class ExecDebuggingListenerTest extends BaseNd4jTestWithBackends {
public ExecDebuggingListenerTest(Nd4jBackend backend) {
super(backend);
}
@Test
public void testExecDebugListener(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExecDebugListener(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);

View File

@ -21,6 +21,8 @@
package org.nd4j.autodiff.samediff.listeners;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Listener;
@ -38,7 +40,7 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
@ -61,11 +63,8 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
public class ListenerTest extends BaseNd4jTest {
public class ListenerTest extends BaseNd4jTestWithBackends {
public ListenerTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -73,7 +72,9 @@ public class ListenerTest extends BaseNd4jTest {
}
@Test
public void irisHistoryTest() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void irisHistoryTest(Nd4jBackend backend) {
DataSetIterator iter = new IrisDataSetIterator(150, 150);
NormalizerStandardize std = new NormalizerStandardize();
@ -136,6 +137,8 @@ public class ListenerTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testListenerCalls(){
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
@ -273,7 +276,9 @@ public class ListenerTest extends BaseNd4jTest {
}
@Test
public void testCustomListener() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCustomListener(Nd4jBackend backend) {
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);

View File

@ -26,11 +26,13 @@ import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.profiler.ProfilingListener;
import org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -45,11 +47,8 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
public class ProfilingListenerTest extends BaseNd4jTest {
public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
public ProfilingListenerTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -59,7 +58,9 @@ public class ProfilingListenerTest extends BaseNd4jTest {
@Test
public void testProfilingListenerSimple(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
@ -107,19 +108,25 @@ public class ProfilingListenerTest extends BaseNd4jTest {
}
/*
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoadTfProfile(){
File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json");
ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoadTfProfileDir(){
File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles");
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLoadTfProfileDir2(){
File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0");
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);

View File

@ -27,6 +27,8 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
@ -41,7 +43,7 @@ import org.nd4j.graph.UIInfoType;
import org.nd4j.graph.UIOp;
import org.nd4j.graph.UIVariable;
import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -60,11 +62,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j
public class FileReadWriteTests extends BaseNd4jTest {
public class FileReadWriteTests extends BaseNd4jTestWithBackends {
public FileReadWriteTests(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -81,7 +80,9 @@ public class FileReadWriteTests extends BaseNd4jTest {
}
@Test
public void testSimple(@TempDir Path testDir) throws IOException {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
SameDiff sd = SameDiff.create();
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
SDVariable sum = v.sum();
@ -185,7 +186,9 @@ public class FileReadWriteTests extends BaseNd4jTest {
}
@Test
public void testNullBinLabels(@TempDir Path testDir) throws Exception{
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{
File dir = testDir.toFile();
File f = new File(dir, "temp.bin");
LogFileWriter w = new LogFileWriter(f);
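
The FileReadWriteTests and ProfilingListenerTest hunks above combine @MethodSource-supplied Nd4jBackend arguments with JUnit 5's @TempDir parameter resolution in a single signature. As a general JUnit 5 note, independent of this commit: the user guide's convention for parameterized methods is to declare the source-provided (indexed) arguments first and parameters filled by other resolvers, such as @TempDir, last. A minimal self-contained sketch of that layout, with hypothetical names standing in for the configs provider:

    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.util.stream.Stream;
    import org.junit.jupiter.api.io.TempDir;
    import org.junit.jupiter.params.ParameterizedTest;
    import org.junit.jupiter.params.provider.Arguments;
    import org.junit.jupiter.params.provider.MethodSource;

    class TempDirWithArgumentsSketch {

        // Hypothetical local provider used only in this sketch.
        static Stream<Arguments> labels() {
            return Stream.of(Arguments.of("first"), Arguments.of("second"));
        }

        // Indexed argument from the source first, resolver-provided @TempDir last.
        @ParameterizedTest
        @MethodSource("labels")
        void writesIntoTempDir(String label, @TempDir Path dir) throws Exception {
            Files.writeString(dir.resolve(label + ".bin"), label); // each invocation gets a fresh directory
        }
    }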

View File

@ -25,6 +25,8 @@ import com.google.flatbuffers.Table;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.listeners.impl.UIListener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -34,7 +36,7 @@ import org.nd4j.graph.UIEvent;
import org.nd4j.graph.UIGraphStructure;
import org.nd4j.graph.UIStaticInfoRecord;
import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.IrisDataSetIterator;
@ -51,11 +53,8 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
public class UIListenerTest extends BaseNd4jTest {
public class UIListenerTest extends BaseNd4jTestWithBackends {
public UIListenerTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -65,7 +64,9 @@ public class UIListenerTest extends BaseNd4jTest {
@Test
public void testUIListenerBasic(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
Nd4j.getRandom().setSeed(12345);
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -101,7 +102,9 @@ public class UIListenerTest extends BaseNd4jTest {
}
@Test
public void testUIListenerContinue(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
SameDiff sd1 = getSimpleNet();
@ -192,7 +195,9 @@ public class UIListenerTest extends BaseNd4jTest {
}
@Test
public void testUIListenerBadContinue(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
SameDiff sd1 = getSimpleNet();

View File

@ -23,18 +23,17 @@ package org.nd4j.evaluation;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.custom.CustomEvaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.primitives.Pair;
public class CustomEvaluationTest extends BaseNd4jTest {
public class CustomEvaluationTest extends BaseNd4jTestWithBackends {
public CustomEvaluationTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -42,8 +41,10 @@ public class CustomEvaluationTest extends BaseNd4jTest {
}
@Test
public void customEvalTest(){
CustomEvaluation accuracyEval = new CustomEvaluation<Pair<Number, Long>>(
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void customEvalTest(Nd4jBackend backend){
CustomEvaluation accuracyEval = new CustomEvaluation<>(
(labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)),
CustomEvaluation.mergeConcatenate());

View File

@ -21,6 +21,8 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration;
@ -29,25 +31,24 @@ import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
public class EmptyEvaluationTests extends BaseNd4jTest {
public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
public EmptyEvaluationTests(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
return 'c';
}
@Test
public void testEmptyEvaluation() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyEvaluation (Nd4jBackend backend) {
Evaluation e = new Evaluation();
System.out.println(e.stats());
@ -62,7 +63,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
}
@Test
public void testEmptyRegressionEvaluation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyRegressionEvaluation (Nd4jBackend backend) {
RegressionEvaluation re = new RegressionEvaluation();
re.stats();
@ -76,7 +79,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
}
@Test
public void testEmptyEvaluationBinary() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyEvaluationBinary(Nd4jBackend backend) {
EvaluationBinary eb = new EvaluationBinary();
eb.stats();
@ -91,7 +96,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
}
@Test
public void testEmptyROC() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyROC(Nd4jBackend backend) {
ROC roc = new ROC();
roc.stats();
@ -106,7 +113,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
}
@Test
public void testEmptyROCBinary() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyROCBinary(Nd4jBackend backend) {
ROCBinary rb = new ROCBinary();
rb.stats();
@ -121,7 +130,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
}
@Test
public void testEmptyROCMultiClass() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyROCMultiClass(Nd4jBackend backend) {
ROCMultiClass r = new ROCMultiClass();
r.stats();
@ -136,7 +147,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
}
@Test
public void testEmptyEvaluationCalibration() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEmptyEvaluationCalibration(Nd4jBackend backend) {
EvaluationCalibration ec = new EvaluationCalibration();
ec.stats();

View File

@ -21,9 +21,11 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
@ -36,11 +38,8 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class EvalCustomThreshold extends BaseNd4jTest {
public class EvalCustomThreshold extends BaseNd4jTestWithBackends {
public EvalCustomThreshold(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -48,7 +47,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
}
@Test
public void testEvaluationCustomBinaryThreshold() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCustomBinaryThreshold(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
//Sanity checks: 0.5 threshold for 1-output and 2-output binary cases
@ -114,7 +115,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
}
@Test
public void testEvaluationCostArray() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCostArray(Nd4jBackend backend) {
int nExamples = 20;
@ -162,7 +165,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
}
@Test
public void testEvaluationBinaryCustomThreshold() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryCustomThreshold(Nd4jBackend backend) {
//Sanity check: same results for 0.5 threshold vs. default (no threshold)
int nExamples = 20;

View File

@ -21,6 +21,8 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration;
@ -31,7 +33,7 @@ import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
@ -42,11 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class EvalJsonTest extends BaseNd4jTest {
public class EvalJsonTest extends BaseNd4jTestWithBackends {
public EvalJsonTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -54,7 +53,9 @@ public class EvalJsonTest extends BaseNd4jTest {
}
@Test
public void testSerdeEmpty() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerdeEmpty(Nd4jBackend backend) {
boolean print = false;
IEvaluation[] arr = new IEvaluation[] {new Evaluation(), new EvaluationBinary(), new ROCBinary(10),
@ -73,8 +74,10 @@ public class EvalJsonTest extends BaseNd4jTest {
}
}
@Test
public void testSerde() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerde(Nd4jBackend backend) {
boolean print = false;
Nd4j.getRandom().setSeed(12345);
@ -121,8 +124,10 @@ public class EvalJsonTest extends BaseNd4jTest {
}
}
@Test
public void testSerdeExactRoc() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerdeExactRoc(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
boolean print = false;
@ -199,8 +204,10 @@ public class EvalJsonTest extends BaseNd4jTest {
}
}
@Test
public void testJsonYamlCurves() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testJsonYamlCurves(Nd4jBackend backend) {
ROC roc = new ROC(0);
INDArray evalLabel =
@ -251,8 +258,10 @@ public class EvalJsonTest extends BaseNd4jTest {
}
@Test
public void testJsonWithCustomThreshold() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testJsonWithCustomThreshold(Nd4jBackend backend) {
//Evaluation - binary threshold
Evaluation e = new Evaluation(0.25);

View File

@ -21,8 +21,10 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -39,11 +41,8 @@ import static org.junit.jupiter.api.Assertions.*;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
public class EvalTest extends BaseNd4jTest {
public class EvalTest extends BaseNd4jTestWithBackends {
public EvalTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -52,7 +51,9 @@ public class EvalTest extends BaseNd4jTest {
@Test
public void testEval() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEval(Nd4jBackend backend) {
int classNum = 5;
Evaluation eval = new Evaluation (classNum);
@ -91,7 +92,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testEval2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEval2(Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
Evaluation first = null;
@ -150,7 +153,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testStringListLabels() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStringListLabels(Nd4jBackend backend) {
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
@ -167,7 +172,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testStringHashLabels() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testStringHashLabels(Nd4jBackend backend) {
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
@ -184,7 +191,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testEvalMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalMasking(Nd4jBackend backend) {
int miniBatch = 5;
int nOut = 3;
int tsLength = 6;
@ -251,7 +260,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testFalsePerfectRecall() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFalsePerfectRecall(Nd4jBackend backend) {
int testSize = 100;
int numClasses = 5;
int winner = 1;
@ -284,7 +295,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testEvaluationMerging() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationMerging(Nd4jBackend backend) {
int nRows = 20;
int nCols = 3;
@ -358,7 +371,9 @@ public class EvalTest extends BaseNd4jTest {
@Test
public void testSingleClassBinaryClassification() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSingleClassBinaryClassification(Nd4jBackend backend) {
Evaluation eval = new Evaluation(1);
@ -387,7 +402,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testEvalInvalid() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalInvalid(Nd4jBackend backend) {
Evaluation e = new Evaluation(5);
e.eval(0, 1);
e.eval(1, 0);
@ -400,7 +417,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testEvalMethods() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalMethods(Nd4jBackend backend) {
//Check eval(int,int) vs. eval(INDArray,INDArray)
Evaluation e1 = new Evaluation(4);
@ -443,7 +462,9 @@ public class EvalTest extends BaseNd4jTest {
@Test
public void testTopNAccuracy() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopNAccuracy(Nd4jBackend backend) {
Evaluation e = new Evaluation(null, 3);
@ -504,7 +525,9 @@ public class EvalTest extends BaseNd4jTest {
@Test
public void testTopNAccuracyMerging() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTopNAccuracyMerging(Nd4jBackend backend) {
Evaluation e1 = new Evaluation(null, 3);
Evaluation e2 = new Evaluation(null, 3);
@ -552,7 +575,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testBinaryCase() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBinaryCase(Nd4jBackend backend) {
INDArray ones10 = Nd4j.ones(10, 1);
INDArray ones4 = Nd4j.ones(4, 1);
INDArray zeros4 = Nd4j.zeros(4, 1);
@ -581,7 +606,9 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
public void testF1FBeta_MicroMacroAveraging() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testF1FBeta_MicroMacroAveraging(Nd4jBackend backend) {
//Confusion matrix: rows = actual, columns = predicted
//[3, 1, 0]
//[2, 2, 1]
@ -722,7 +749,9 @@ public class EvalTest extends BaseNd4jTest {
@Test
public void testConfusionMatrixStats() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConfusionMatrixStats(Nd4jBackend backend) {
Evaluation e = new Evaluation();
@ -743,6 +772,8 @@ public class EvalTest extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalBinaryMetrics(){
Evaluation ePosClass1_nOut2 = new Evaluation(2, 1);
@ -864,6 +895,8 @@ public class EvalTest extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConfusionMatrixString(){
Evaluation e = new Evaluation(Arrays.asList("a","b","c"));
@ -914,6 +947,8 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationNaNs(){
Evaluation e = new Evaluation();
@ -929,6 +964,8 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345);
@ -1023,6 +1060,8 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLabelReset(){
Map<Integer,String> m = new HashMap<>();
@ -1056,6 +1095,8 @@ public class EvalTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalStatsBinaryCase(){
//Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case

View File

@ -21,9 +21,11 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -38,11 +40,8 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.nd4j.evaluation.classification.EvaluationBinary.Metric.*;
public class EvaluationBinaryTest extends BaseNd4jTest {
public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
public EvaluationBinaryTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -50,7 +49,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinary() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary(Nd4jBackend backend) {
//Compare EvaluationBinary to Evaluation class
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
EvaluationBinary first = null;
@ -136,7 +137,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinaryMerging() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryMerging(Nd4jBackend backend) {
int nOut = 4;
int[] shape1 = {30, nOut};
int[] shape2 = {50, nOut};
@ -163,7 +166,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinaryPerOutputMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryPerOutputMasking(Nd4jBackend backend) {
//Provide a mask array: "ignore" the masked steps
@ -172,7 +177,7 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
INDArray labels = Nd4j.create(new double[][] {{1, 1, 1}, {0, 0, 0}, {1, 1, 1}, {0, 1, 1}, {1, 0, 1}});
INDArray predicted = Nd4j.create(new double[][] {{0.9, 0.9, 0.9}, {0.7, 0.7, 0.7}, {0.6, 0.6, 0.6},
{0.4, 0.4, 0.4}, {0.1, 0.1, 0.1}});
{0.4, 0.4, 0.4}, {0.1, 0.1, 0.1}});
//Correct?
// Y Y m
@ -206,7 +211,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testTimeSeriesEval() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTimeSeriesEval(Nd4jBackend backend) {
int[] shape = {2, 4, 3};
Nd4j.getRandom().setSeed(12345);
@ -230,12 +237,14 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinaryWithROC() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinaryWithROC(Nd4jBackend backend) {
//Simple test for nested ROCBinary in EvaluationBinary
Nd4j.getRandom().setSeed(12345);
INDArray l1 = Nd4j.getExecutioner()
.exec(new BernoulliDistribution(Nd4j.createUninitialized(new int[] {50, 4}), 0.5));
.exec(new BernoulliDistribution(Nd4j.createUninitialized(new int[] {50, 4}), 0.5));
INDArray p1 = Nd4j.rand(50, 4);
EvaluationBinary eb = new EvaluationBinary(4, 30);
@ -247,7 +256,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
@Test
public void testEvaluationBinary3d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -281,7 +292,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinary4d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -315,7 +328,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinary3dMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -376,7 +391,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
}
@Test
public void testEvaluationBinary4dMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationBinary4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -21,8 +21,10 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -39,19 +41,18 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
public class EvaluationCalibrationTest extends BaseNd4jTest {
public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
public EvaluationCalibrationTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
public char ordering () {
return 'c';
}
@Test
public void testReliabilityDiagram() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReliabilityDiagram (Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
EvaluationCalibration first = null;
@ -142,8 +143,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
}
}
@Test
public void testLabelAndPredictionCounts() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLabelAndPredictionCounts (Nd4jBackend backend) {
int minibatch = 50;
int nClasses = 3;
@ -170,8 +173,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
assertArrayEquals(expPredictionCount, ec.getPredictionCountsEachClass());
}
@Test
public void testResidualPlots() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testResidualPlots (Nd4jBackend backend) {
int minibatch = 50;
int nClasses = 3;
@ -271,7 +276,9 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
}
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345);
@ -365,8 +372,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
}
}
@Test
public void testEvaluationCalibration3d() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCalibration3d (Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -397,8 +406,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
assertEquals(e2d.stats(), e3d.stats());
}
@Test
public void testEvaluationCalibration3dMasking() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvaluationCalibration3dMasking (Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);

View File

@ -23,6 +23,8 @@ package org.nd4j.evaluation;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration;
@ -30,17 +32,14 @@ import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
public class NewInstanceTest extends BaseNd4jTest {
public class NewInstanceTest extends BaseNd4jTestWithBackends {
public NewInstanceTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -48,7 +47,9 @@ public class NewInstanceTest extends BaseNd4jTest {
}
@Test
public void testNewInstances() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewInstances(Nd4jBackend backend) {
boolean print = true;
Nd4j.getRandom().setSeed(12345);

View File

@ -21,10 +21,12 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -39,19 +41,17 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class ROCBinaryTest extends BaseNd4jTest {
public ROCBinaryTest(Nd4jBackend backend) {
super(backend);
}
public class ROCBinaryTest extends BaseNd4jTestWithBackends {
@Override
public char ordering() {
return 'c';
}
@Test
public void testROCBinary() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary(Nd4jBackend backend) {
//Compare ROCBinary to ROC class
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
@ -145,8 +145,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
}
}
@Test
public void testRocBinaryMerging() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBinaryMerging(Nd4jBackend backend) {
for (int nSteps : new int[]{30, 0}) { //0 == exact
int nOut = 4;
int[] shape1 = {30, nOut};
@ -175,8 +177,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
}
@Test
public void testROCBinaryPerOutputMasking() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinaryPerOutputMasking(Nd4jBackend backend) {
for (int nSteps : new int[]{30, 0}) { //0 == exact
@ -215,8 +219,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
@Test
public void testROCBinary3d() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -249,8 +255,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
}
}
@Test
public void testROCBinary4d() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -283,8 +291,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
}
}
@Test
public void testROCBinary3dMasking() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -344,8 +354,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
}
}
@Test
public void testROCBinary4dMasking() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCBinary4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -21,12 +21,14 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -39,11 +41,8 @@ import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
public class ROCTest extends BaseNd4jTest {
public class ROCTest extends BaseNd4jTestWithBackends {
public ROCTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -83,8 +82,10 @@ public class ROCTest extends BaseNd4jTest {
expFPR.put(10 / 10.0, 0.0 / totalNegatives);
}
@Test
public void testRocBasic() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBasic(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax)
INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
@ -126,8 +127,10 @@ public class ROCTest extends BaseNd4jTest {
assertEquals(1.0, auc, 1e-6);
}
@Test
public void testRocBasicSingleClass() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBasicSingleClass(Nd4jBackend backend) {
//1 output here - single probability value (sigmoid)
//add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
@ -164,8 +167,10 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void testRoc() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRoc(Nd4jBackend backend) {
//Previous tests allowed for a perfect classifier with right threshold...
INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}});
@ -249,8 +254,10 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void testRocTimeSeriesNoMasking() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
//Same as first test...
//2 outputs here - probability distribution over classes (softmax)
@ -296,8 +303,10 @@ public class ROCTest extends BaseNd4jTest {
}
}
@Test
public void testRocTimeSeriesMasking() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocTimeSeriesMasking(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax)
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
@ -346,8 +355,10 @@ public class ROCTest extends BaseNd4jTest {
@Test
public void testCompareRocAndRocMultiClass() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
//For 2 class case: ROC and Multi-class ROC should be the same...
@ -376,8 +387,10 @@ public class ROCTest extends BaseNd4jTest {
}
}
@Test
public void testCompare2Vs3Classes() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCompare2Vs3Classes(Nd4jBackend backend) {
//ROC multi-class: 2 vs. 3 classes should be the same, if we add two of the classes together...
//Both methods implement one vs. all ROC/AUC in different ways
@ -425,8 +438,10 @@ public class ROCTest extends BaseNd4jTest {
}
}
@Test
public void testROCMerging() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCMerging(Nd4jBackend backend) {
int nArrays = 10;
int minibatch = 64;
int nROCs = 3;
@ -470,8 +485,10 @@ public class ROCTest extends BaseNd4jTest {
}
}
@Test
public void testROCMerging2() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCMerging2(Nd4jBackend backend) {
int nArrays = 10;
int minibatch = 64;
int exactAllocBlockSize = 10;
@ -515,8 +532,10 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void testROCMultiMerging() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testROCMultiMerging(Nd4jBackend backend) {
int nArrays = 10;
int minibatch = 64;
@ -563,8 +582,10 @@ public class ROCTest extends BaseNd4jTest {
}
}
@Test
public void testAUCPrecisionRecall() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAUCPrecisionRecall(Nd4jBackend backend) {
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
//at threshold 0 to 0.24999: tp=2, fp=1, fn=0, tn=0 prec=2/(2+1)=0.666, recall=2/2=1.0
//at threshold 0.25 to 0.33: tp=2, fp=0, fn=0, tn=1 prec=2/2=1, recall=2/2=1
@ -610,8 +631,10 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void testRocAucExact() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocAucExact(Nd4jBackend backend) {
//Check the implementation vs. Scikitlearn
/*
@ -773,8 +796,10 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void rocExactEdgeCaseReallocation() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
//Set reallocation block size to say 20, but then evaluate a 100-length array
@ -785,8 +810,10 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
public void testPrecisionRecallCurveGetPointMethods() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
double[] threshold = new double[101];
double[] precision = threshold;
double[] recall = new double[101];
@ -821,8 +848,10 @@ public class ROCTest extends BaseNd4jTest {
}
}
@Test
public void testPrecisionRecallCurveConfusion() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
//Sanity check: values calculated from the confusion matrix should match the PR curve values
for (boolean removeRedundantPts : new boolean[] {true, false}) {
@ -860,7 +889,9 @@ public class ROCTest extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocMerge(){
Nd4j.getRandom().setSeed(12345);
@ -904,7 +935,9 @@ public class ROCTest extends BaseNd4jTest {
assertEquals(auprc, auprcAct, 1e-6);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocMultiMerge(){
Nd4j.getRandom().setSeed(12345);
@ -953,7 +986,9 @@ public class ROCTest extends BaseNd4jTest {
}
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRocBinaryMerge(){
Nd4j.getRandom().setSeed(12345);
@ -998,7 +1033,9 @@ public class ROCTest extends BaseNd4jTest {
}
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentationBinary(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345);
@ -1088,7 +1125,9 @@ public class ROCTest extends BaseNd4jTest {
}
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
Nd4j.getRandom().setSeed(12345);

View File

@ -21,9 +21,11 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -40,11 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
public class RegressionEvalTest extends BaseNd4jTest {
public class RegressionEvalTest extends BaseNd4jTestWithBackends {
public RegressionEvalTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
@ -52,7 +51,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test()
public void testEvalParameters() {
public void testEvalParameters(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
int specCols = 5;
INDArray labels = Nd4j.ones(3);
@ -65,7 +64,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testPerfectPredictions() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPerfectPredictions(Nd4jBackend backend) {
int nCols = 5;
int nTestArrays = 100;
@ -92,7 +93,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testKnownValues() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testKnownValues(Nd4jBackend backend) {
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
RegressionEvaluation first = null;
@ -148,7 +151,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
@Test
public void testRegressionEvaluationMerging() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEvaluationMerging(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
int nRows = 20;
@ -189,7 +194,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testRegressionEvalPerOutputMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEvalPerOutputMasking(Nd4jBackend backend) {
INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}});
@ -216,6 +223,8 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEvalTimeSeriesSplit(){
INDArray out1 = Nd4j.rand(new int[]{3, 5, 20});
@ -238,7 +247,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testRegressionEval3d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -270,7 +281,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testRegressionEval4d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -302,7 +315,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testRegressionEval3dMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -361,7 +376,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
}
@Test
public void testRegressionEval4dMasking() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRegressionEval4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -22,10 +22,12 @@ package org.nd4j.evaluation;
import org.apache.commons.io.FileUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.io.ClassPathResource;
@ -34,11 +36,8 @@ import java.nio.charset.StandardCharsets;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestLegacyJsonLoading extends BaseNd4jTest {
public class TestLegacyJsonLoading extends BaseNd4jTestWithBackends {
public TestLegacyJsonLoading(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -46,7 +45,9 @@ public class TestLegacyJsonLoading extends BaseNd4jTest {
}
@Test
public void testEvalLegacyFormat() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEvalLegacyFormat(Nd4jBackend backend) throws Exception {
File f = new ClassPathResource("regression_testing/eval_100b/evaluation.json").getFile();
String s = FileUtils.readFileToString(f, StandardCharsets.UTF_8);

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -38,17 +39,14 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
@Slf4j
@RunWith(Parameterized.class)
public class AveragingTests extends BaseNd4jTest {
public class AveragingTests extends BaseNd4jTestWithBackends {
private final int THREADS = 16;
private final int LENGTH = 51200 * 4;
DataType initialType;
DataType initialType = Nd4j.dataType();
public AveragingTests(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@BeforeEach
public void setUp() {
@ -63,7 +61,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test
public void testSingleDeviceAveraging1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSingleDeviceAveraging1(Nd4jBackend backend) {
INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0);
INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0);
INDArray array3 = Nd4j.valueArrayOf(LENGTH, 3.0);
@ -110,7 +110,9 @@ public class AveragingTests extends BaseNd4jTest {
}
@Test
public void testSingleDeviceAveraging2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSingleDeviceAveraging2(Nd4jBackend backend) {
INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH);
List<INDArray> arrays = new ArrayList<>();
for (int i = 0; i < THREADS; i++)
@ -127,7 +129,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test
public void testAccumulation1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAccumulation1(Nd4jBackend backend) {
INDArray array1 = Nd4j.create(100).assign(1.0);
INDArray array2 = Nd4j.create(100).assign(2.0);
INDArray array3 = Nd4j.create(100).assign(3.0);
@ -140,7 +144,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test
public void testAccumulation2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAccumulation2(Nd4jBackend backend) {
INDArray array1 = Nd4j.create(100).assign(1.0);
INDArray array2 = Nd4j.create(100).assign(2.0);
INDArray array3 = Nd4j.create(100).assign(3.0);
@ -155,7 +161,9 @@ public class AveragingTests extends BaseNd4jTest {
@Test
public void testAccumulation3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAccumulation3(Nd4jBackend backend) {
// we want to ensure that cuda backend is able to launch this op on cpu
Nd4j.getAffinityManager().allowCrossDeviceAccess(false);

View File

@ -23,8 +23,9 @@ package org.nd4j.linalg;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -34,15 +35,14 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
@Slf4j
public class DataTypeTest extends BaseNd4jTest {
public DataTypeTest(Nd4jBackend backend) {
super(backend);
}
public class DataTypeTest extends BaseNd4jTestWithBackends {
@Test
public void testDataTypes() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDataTypes(Nd4jBackend backend) throws Exception {
for (val type : DataType.values()) {
if (DataType.UTF8.equals(type) || DataType.UNKNOWN.equals(type) || DataType.COMPRESSED.equals(type))
continue;

View File

@ -21,20 +21,17 @@
package org.nd4j.linalg;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.fail;
@RunWith(Parameterized.class)
public class InputValidationTests extends BaseNd4jTest {
public InputValidationTests(Nd4jBackend backend) {
super(backend);
}
public class InputValidationTests extends BaseNd4jTestWithBackends {
@Override
public char ordering() {
@ -45,7 +42,9 @@ public class InputValidationTests extends BaseNd4jTest {
///////////////////// Broadcast Tests ///////////////////////
@Test
public void testInvalidColVectorOp1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidColVectorOp1(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10);
INDArray col = Nd4j.create(5, 1);
try {
@ -57,7 +56,9 @@ public class InputValidationTests extends BaseNd4jTest {
}
@Test
public void testInvalidColVectorOp2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidColVectorOp2(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10);
INDArray col = Nd4j.create(5, 1);
try {
@ -69,7 +70,9 @@ public class InputValidationTests extends BaseNd4jTest {
}
@Test
public void testInvalidRowVectorOp1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidRowVectorOp1(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10);
INDArray row = Nd4j.create(1, 5);
try {
@ -81,7 +84,9 @@ public class InputValidationTests extends BaseNd4jTest {
}
@Test
public void testInvalidRowVectorOp2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInvalidRowVectorOp2(Nd4jBackend backend) {
INDArray first = Nd4j.create(10, 10);
INDArray row = Nd4j.create(1, 5);
try {
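
The InputValidationTests bodies above keep the try/catch-then-fail() idiom for expected failures. Purely as an illustrative sketch, not something this commit does, the same check can be phrased with assertThrows, which the migrated code already uses elsewhere in this diff (RegressionEvalTest.testEvalParameters, LoneTest.checkIllegalElementOps). The operation inside the try block is cut off by the hunk, so an in-place column-vector op is assumed here, and the sketch assumes the nd4j test base classes are on the classpath so the configs provider resolves:

    import static org.junit.jupiter.api.Assertions.assertThrows;
    import org.junit.jupiter.params.ParameterizedTest;
    import org.junit.jupiter.params.provider.MethodSource;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.Nd4jBackend;

    class InputValidationSketch {

        // Mirrors testInvalidColVectorOp1 above: a 5-element column vector cannot be applied
        // to a 10x10 matrix, so the (assumed) in-place column-vector op is expected to throw.
        @ParameterizedTest
        @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
        void invalidColVectorOpThrows(Nd4jBackend backend) {
            INDArray first = Nd4j.create(10, 10);
            INDArray col = Nd4j.create(5, 1);
            assertThrows(Exception.class, () -> first.addiColumnVector(col));
        }
    }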

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.RandomUtils;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
@ -47,14 +48,13 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j
@RunWith(Parameterized.class)
public class LoneTest extends BaseNd4jTest {
public LoneTest(Nd4jBackend backend) {
super(backend);
}
public class LoneTest extends BaseNd4jTestWithBackends {
@Test
public void testSoftmaxStability() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSoftmaxStability(Nd4jBackend backend) {
INDArray input = Nd4j.create(new double[]{-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose();
// System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo()));
INDArray output = Nd4j.create(DataType.DOUBLE, 10, 1);
@ -68,7 +68,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void testFlattenedView() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testFlattenedView(Nd4jBackend backend) {
int rows = 8;
int cols = 8;
int dim2 = 4;
@ -104,7 +106,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void testIndexingColVec() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexingColVec(Nd4jBackend backend) {
int elements = 5;
INDArray rowVector = Nd4j.linspace(1, elements, elements).reshape(1, elements);
INDArray colVector = rowVector.transpose();
@ -123,7 +127,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void concatScalarVectorIssue() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void concatScalarVectorIssue(Nd4jBackend backend) {
//A bug was found when the first array that concat sees is a scalar and the rest vectors + scalars
INDArray arr1 = Nd4j.create(1, 1);
INDArray arr2 = Nd4j.create(1, 8);
@ -133,7 +139,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void reshapeTensorMmul() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void reshapeTensorMmul(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 2, 12).reshape(2, 3, 2);
INDArray b = Nd4j.linspace(3, 4, 4).reshape(2, 2);
int[][] axes = new int[2][];
@ -145,7 +153,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void maskWhenMerge() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void maskWhenMerge(Nd4jBackend backend) {
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
List<DataSet> dataSetList = new ArrayList<DataSet>();
@ -160,7 +170,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void testRelu() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRelu(Nd4jBackend backend) {
INDArray aA = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
INDArray b = Nd4j.getExecutioner().exec(new Tanh(aA));
@ -172,7 +184,7 @@ public class LoneTest extends BaseNd4jTest {
@Test
//broken at a threshold
public void testArgMax() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testArgMax(Nd4jBackend backend) {
int max = 63;
INDArray A = Nd4j.linspace(1, max, max).reshape(1, max);
int currentArgMax = Nd4j.argMax(A).getInt(0);
@ -186,7 +198,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void testRPF() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRPF(Nd4jBackend backend) {
val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3);
log.info("--------");
@ -199,7 +213,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void testConcat3D_Vstack_C() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat3D_Vstack_C(Nd4jBackend backend) {
val shape = new long[]{1, 1000, 20};
List<INDArray> cArrays = new ArrayList<>();
@ -229,7 +245,9 @@ public class LoneTest extends BaseNd4jTest {
@Test
public void testGetRow1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRow1(Nd4jBackend backend) {
INDArray array = Nd4j.create(10000, 10000);
//Thread.sleep(10000);
@ -256,7 +274,7 @@ public class LoneTest extends BaseNd4jTest {
}
@Test()
public void checkIllegalElementOps() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void checkIllegalElementOps(Nd4jBackend backend) {
assertThrows(Exception.class,() -> {
INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5);
INDArray B = A.dup().reshape(2, 2, 5);
@ -268,7 +286,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void checkSliceofSlice() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void checkSliceofSlice(Nd4jBackend backend) {
/*
Issue 1: Slice of slice with c order and f order views are not equal
@ -308,7 +328,9 @@ public class LoneTest extends BaseNd4jTest {
}
@Test
public void checkWithReshape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void checkWithReshape(Nd4jBackend backend) {
INDArray arr = Nd4j.create(1, 3);
INDArray reshaped = arr.reshape('f', 3, 1);
for (int i=0;i<reshaped.length();i++) {

View File
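Taken together, each file above follows the same mechanical migration: drop @RunWith(Parameterized.class) and the backend constructor, extend BaseNd4jTestWithBackends, and let JUnit 5 inject the backend per invocation. A condensed sketch of the resulting shape (class and method names here are illustrative only):

import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;

// No runner annotation and no backend constructor: the backend arrives as a
// method parameter supplied by the @MethodSource provider.
public class ExampleMigratedTest extends BaseNd4jTestWithBackends {

    @Override
    public char ordering() {
        return 'c';
    }

    @ParameterizedTest
    @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
    public void runsOncePerBackend(Nd4jBackend backend) {
        INDArray ones = Nd4j.ones(2, 2);
        assertEquals(4.0, ones.sumNumber().doubleValue(), 1e-6);
    }
}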

@ -21,6 +21,8 @@
package org.nd4j.linalg;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -28,11 +30,8 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class MmulBug extends BaseNd4jTest {
public class MmulBug extends BaseNd4jTestWithBackends {
public MmulBug(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
@ -40,7 +39,9 @@ public class MmulBug extends BaseNd4jTest {
}
@Test
public void simpleTest() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void simpleTest(Nd4jBackend backend) {
INDArray m1 = Nd4j.create(new double[][] {{1.0}, {2.0}, {3.0}, {4.0}});
m1 = m1.reshape(2, 2);

View File

@ -25,8 +25,9 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -58,19 +59,14 @@ import static org.junit.jupiter.api.Assertions.*;
*
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
@Slf4j
public class NDArrayTestsFortran extends BaseNd4jTest {
public NDArrayTestsFortran(Nd4jBackend backend) {
super(backend);
}
public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
@Test
public void testScalarOps() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalarOps(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
assertEquals(27d, n.length(), 1e-1);
n.addi(Nd4j.scalar(1d));
@ -88,7 +84,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testColumnMmul() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testColumnMmul(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 10, 18, DataType.FLOAT).data();
INDArray x2 = Nd4j.create(data, new long[] {2, 3, 3});
data = Nd4j.linspace(1, 12, 9, DataType.FLOAT).data();
@ -119,7 +117,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testRowVectorGemm() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRowVectorGemm(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1).castTo(DataType.DOUBLE);
INDArray other = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4).castTo(DataType.DOUBLE);
INDArray result = linspace.mmul(other);
@ -130,13 +130,17 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testRepmat() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRepmat(Nd4jBackend backend) {
INDArray rowVector = Nd4j.create(1, 4);
INDArray repmat = rowVector.repmat(4, 4);
assertTrue(Arrays.equals(new long[] {4, 16}, repmat.shape()));
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReadWrite() throws Exception {
INDArray write = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
ByteArrayOutputStream bos = new ByteArrayOutputStream();
@ -152,6 +156,8 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReadWriteDouble() throws Exception {
INDArray write = Nd4j.linspace(1, 4, 4, DataType.FLOAT);
ByteArrayOutputStream bos = new ByteArrayOutputStream();
@ -168,18 +174,17 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMultiThreading() throws Exception {
ExecutorService ex = ExecutorServiceProvider.getExecutorService();
List<Future<?>> list = new ArrayList<>(100);
for (int i = 0; i < 100; i++) {
Future<?> future = ex.submit(new Runnable() {
@Override
public void run() {
INDArray dot = Nd4j.linspace(1, 8, 8, DataType.DOUBLE);
Future<?> future = ex.submit(() -> {
INDArray dot = Nd4j.linspace(1, 8, 8, DataType.DOUBLE);
// System.out.println(Transforms.sigmoid(dot));
Transforms.sigmoid(dot);
}
Transforms.sigmoid(dot);
});
list.add(future);
}
@ -191,7 +196,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testBroadcastingGenerated() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadcastingGenerated(Nd4jBackend backend) {
int[][] broadcastShape = NDArrayCreationUtil.getRandomBroadCastShape(7, 6, 10);
List<List<Pair<INDArray, String>>> broadCastList = new ArrayList<>(broadcastShape.length);
for (int[] shape : broadcastShape) {
@ -206,7 +213,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
INDArray inputArrBroadcast = val.getFirst();
val destShape = NDArrayCreationUtil.broadcastToShape(inputArrBroadcast.shape(), 7);
INDArray output = inputArrBroadcast
.broadcast(NDArrayCreationUtil.broadcastToShape(inputArrBroadcast.shape(), 7));
.broadcast(NDArrayCreationUtil.broadcastToShape(inputArrBroadcast.shape(), 7));
assertArrayEquals(destShape, output.shape());
}
}
@ -216,7 +223,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testBroadCasting() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadCasting(Nd4jBackend backend) {
INDArray first = Nd4j.arange(0, 3).reshape(3, 1).castTo(DataType.DOUBLE);
INDArray ret = first.broadcast(3, 4);
INDArray testRet = Nd4j.create(new double[][] {{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}});
@ -229,14 +238,18 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testOneTensor() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOneTensor(Nd4jBackend backend) {
INDArray arr = Nd4j.ones(1, 1, 1, 1, 1, 1, 1);
INDArray matrixToBroadcast = Nd4j.ones(1, 1);
assertEquals(matrixToBroadcast.broadcast(arr.shape()), arr);
}
@Test
public void testSortWithIndicesDescending() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSortWithIndicesDescending(Nd4jBackend backend) {
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
//indices,data
INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, false);
@ -247,7 +260,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testSortDeadlock() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSortDeadlock(Nd4jBackend backend) {
val toSort = Nd4j.linspace(DataType.DOUBLE, 1, 32*768, 1).reshape(32, 768);
val sorted = Nd4j.sort(toSort.dup(), 1, false);
@ -255,7 +270,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testSortWithIndices() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSortWithIndices(Nd4jBackend backend) {
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
//indices,data
INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, true);
@ -266,14 +283,18 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testNd4jSortScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNd4jSortScalar(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(1, -1);
INDArray sorted = Nd4j.sort(linspace, 1, false);
// System.out.println(sorted);
}
@Test
public void testSwapAxesFortranOrder() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSwapAxesFortranOrder(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}).castTo(DataType.DOUBLE);
for (int i = 0; i < n.slices(); i++) {
INDArray nSlice = n.slice(i);
@ -292,7 +313,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testDimShuffle() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDimShuffle(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray twoOneTwo = n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false});
assertTrue(Arrays.equals(new long[] {2, 1, 2}, twoOneTwo.shape()));
@ -303,7 +326,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testGetVsGetScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetVsGetScalar(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
float element = a.getFloat(0, 1);
double element2 = a.getDouble(0, 1);
@ -316,7 +341,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testDivide() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDivide(Nd4jBackend backend) {
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
INDArray div = two.div(two);
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage());
@ -330,7 +357,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testSigmoid() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSigmoid(Nd4jBackend backend) {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
INDArray sigmoid = Transforms.sigmoid(n, false);
@ -339,7 +368,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testNeg() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNeg(Nd4jBackend backend) {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
INDArray neg = Transforms.neg(n);
@ -349,7 +380,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testCosineSim() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCosineSim(Nd4jBackend backend) {
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
double sim = Transforms.cosineSim(vec1, vec2);
@ -364,7 +397,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testExp() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testExp(Nd4jBackend backend) {
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f});
INDArray exped = Transforms.exp(n);
@ -374,7 +409,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testScalar(Nd4jBackend backend) {
INDArray a = Nd4j.scalar(1.0f);
assertEquals(true, a.isScalar());
@ -386,7 +423,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testWrap() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testWrap(Nd4jBackend backend) {
int[] shape = {2, 4};
INDArray d = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(shape[0], shape[1]);
INDArray n = d;
@ -411,7 +450,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testGetRowFortran() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRowFortran(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.FLOAT).data(), new long[] {2, 2});
INDArray column = Nd4j.create(new float[] {1, 3});
INDArray column2 = Nd4j.create(new float[] {2, 4});
@ -424,7 +465,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testGetColumnFortran() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumnFortran(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2});
INDArray column = Nd4j.create(new double[] {1, 2});
INDArray column2 = Nd4j.create(new double[] {3, 4});
@ -438,7 +481,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testGetColumns() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumns(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3).castTo(DataType.DOUBLE);
// log.info("Original: {}", matrix);
INDArray matrixGet = matrix.getColumns(1, 2);
@ -452,7 +497,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testVectorInit() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorInit(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data();
INDArray arr = Nd4j.create(data, new long[] {1, 4});
assertEquals(true, arr.isRowVector());
@ -465,7 +512,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testAssignOffset() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssignOffset(Nd4jBackend backend) {
INDArray arr = Nd4j.ones(5, 5);
INDArray row = arr.slice(1);
row.assign(1);
@ -473,7 +522,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testColumns() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testColumns(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE);
INDArray column = Nd4j.create(new double[] {1, 2, 3});
arr.putColumn(0, column);
@ -511,7 +562,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testPutRow() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRow(Nd4jBackend backend) {
INDArray d = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray n = d.dup();
@ -570,7 +623,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testInplaceTranspose() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInplaceTranspose(Nd4jBackend backend) {
INDArray test = Nd4j.rand(3, 4);
INDArray orig = test.dup();
INDArray transposei = test.transposei();
@ -585,7 +640,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testMmulF() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulF(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data();
INDArray n = Nd4j.create(data, new long[] {1, 10});
@ -603,7 +660,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testRowsColumns() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRowsColumns(Nd4jBackend backend) {
DataBuffer data = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).data();
INDArray rows = Nd4j.create(data, new long[] {2, 3});
assertEquals(2, rows.rows());
@ -619,7 +678,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testTranspose() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTranspose(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.ones(100).castTo(DataType.DOUBLE).data(), new long[] {5, 5, 4});
INDArray transpose = n.transpose();
assertEquals(n.length(), transpose.length());
@ -647,7 +708,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testAddMatrix() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddMatrix(Nd4jBackend backend) {
INDArray five = Nd4j.ones(5);
five.addi(five.dup());
INDArray twos = Nd4j.valueArrayOf(5, 2);
@ -658,7 +721,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testMMul() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMMul(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
@ -669,7 +734,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testPutSlice() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutSlice(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3);
INDArray newSlice = Nd4j.create(DataType.DOUBLE, 3, 3);
Nd4j.exec(new PrintVariable(newSlice));
@ -680,7 +747,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testRowVectorMultipleIndices() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
linear.putScalar(new long[] {0, 1}, 1);
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage());
@ -689,7 +758,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testDim1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDim1(Nd4jBackend backend) {
INDArray sum = Nd4j.linspace(1, 2, 2, DataType.DOUBLE).reshape(2, 1);
INDArray same = sum.dup();
assertEquals(same.sum(1), sum.reshape(2));
@ -697,7 +768,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testEps() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEps(Nd4jBackend backend) {
val ones = Nd4j.ones(5);
val res = Nd4j.createUninitialized(DataType.BOOL, 5);
assertTrue(Nd4j.getExecutioner().exec(new Eps(ones, ones, res)).all());
@ -705,7 +778,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testLogDouble() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testLogDouble(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).castTo(DataType.DOUBLE);
INDArray log = Transforms.log(linspace);
INDArray assertion = Nd4j.create(new double[] {0, 0.6931471805599453, 1.0986122886681098, 1.3862943611198906, 1.6094379124341005, 1.791759469228055});
@ -713,28 +788,36 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testVectorSum() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorSum(Nd4jBackend backend) {
INDArray lin = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
}
@Test
public void testVectorSum2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorSum2(Nd4jBackend backend) {
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
}
@Test
public void testVectorSum3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorSum3(Nd4jBackend backend) {
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray lin2 = Nd4j.create(new double[] {1, 2, 3, 4});
assertEquals(lin, lin2);
}
@Test
public void testSmallSum() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSmallSum(Nd4jBackend backend) {
INDArray base = Nd4j.create(new double[] {5.843333333333335, 3.0540000000000007});
base.addi(1e-12);
INDArray assertion = Nd4j.create(new double[] {5.84333433, 3.054001});
@ -745,7 +828,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testPermute() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPermute(Nd4jBackend backend) {
INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4});
INDArray transpose = n.transpose();
INDArray permute = n.permute(1, 0);
@ -774,7 +859,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testAppendBias() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAppendBias(Nd4jBackend backend) {
INDArray rand = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(1, -1).transpose();
INDArray test = Nd4j.appendBias(rand);
INDArray assertion = Nd4j.toFlattened(rand, Nd4j.scalar(DataType.DOUBLE, 1.0)).reshape(-1, 1);
@ -782,7 +869,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testRand() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRand(Nd4jBackend backend) {
INDArray rand = Nd4j.randn(5, 5);
Nd4j.getDistributions().createUniform(0.4, 4).sample(5);
Nd4j.getDistributions().createNormal(1, 5).sample(10);
@ -794,7 +883,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testIdentity() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIdentity(Nd4jBackend backend) {
INDArray eye = Nd4j.eye(5);
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
eye = Nd4j.eye(5);
@ -805,7 +896,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testColumnVectorOpsFortran() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testColumnVectorOpsFortran(Nd4jBackend backend) {
INDArray twoByTwo = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray toAdd = Nd4j.create(new float[] {1, 2}, new long[] {2, 1});
twoByTwo.addiColumnVector(toAdd);
@ -816,7 +909,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testRSubi() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRSubi(Nd4jBackend backend) {
INDArray n2 = Nd4j.ones(2);
INDArray n2Assertion = Nd4j.zeros(2);
INDArray nRsubi = n2.rsubi(1);
@ -826,7 +921,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testAssign() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssign(Nd4jBackend backend) {
INDArray vector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
vector.assign(1);
assertEquals(Nd4j.ones(5).castTo(DataType.DOUBLE), vector);
@ -843,7 +940,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testAddScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddScalar(Nd4jBackend backend) {
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
INDArray rdiv = div.add(1);
INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 5.0);
@ -851,7 +950,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testRdivScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRdivScalar(Nd4jBackend backend) {
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
INDArray rdiv = div.rdiv(1);
INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 0.25);
@ -859,7 +960,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testRDivi() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRDivi(Nd4jBackend backend) {
INDArray n2 = Nd4j.valueArrayOf(new long[] {1, 2}, 4.0);
INDArray n2Assertion = Nd4j.valueArrayOf(new long[] {1, 2}, 0.5);
INDArray nRsubi = n2.rdivi(2);
@ -869,7 +972,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testNumVectorsAlongDimension() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNumVectorsAlongDimension(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2);
assertEquals(12, arr.vectorsAlongDimension(2));
}
@ -877,7 +982,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testBroadCast() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBroadCast(Nd4jBackend backend) {
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
INDArray broadCasted = n.broadcast(5, 4);
for (int i = 0; i < broadCasted.rows(); i++) {
@ -899,7 +1006,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testMatrix() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMatrix(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray brr = Nd4j.create(new double[] {5, 6}, new long[] {2});
INDArray row = arr.getRow(0);
@ -909,7 +1018,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testPutRowGetRowOrdering() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRowGetRowOrdering(Nd4jBackend backend) {
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray put = Nd4j.create(new double[] {5, 6});
row1.putRow(1, put);
@ -931,7 +1042,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testSumWithRow1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSumWithRow1(Nd4jBackend backend) {
//Works:
INDArray array2d = Nd4j.ones(1, 10);
array2d.sum(0); //OK
@ -962,7 +1075,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testSumWithRow2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSumWithRow2(Nd4jBackend backend) {
//All sums in this method execute without exceptions.
INDArray array3d = Nd4j.ones(2, 10, 10);
array3d.sum(0);
@ -985,7 +1100,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testPutRowFortran() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRowFortran(Nd4jBackend backend) {
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2).castTo(DataType.DOUBLE);
INDArray put = Nd4j.create(new double[] {5, 6});
row1.putRow(1, put);
@ -998,7 +1115,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testElementWiseOps() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testElementWiseOps(Nd4jBackend backend) {
INDArray n1 = Nd4j.scalar(1);
INDArray n2 = Nd4j.scalar(2);
INDArray nClone = n1.add(n2);
@ -1021,7 +1140,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
public void testRollAxis() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRollAxis(Nd4jBackend backend) {
INDArray toRoll = Nd4j.ones(3, 4, 5, 6);
assertArrayEquals(new long[] {3, 6, 4, 5}, Nd4j.rollAxis(toRoll, 3, 1).shape());
val shape = Nd4j.rollAxis(toRoll, 3).shape();
@ -1030,20 +1151,22 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
@Test
@Disabled
public void testTensorDot() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTensorDot(Nd4jBackend backend) {
INDArray oneThroughSixty = Nd4j.arange(60).reshape('f', 3, 4, 5).castTo(DataType.DOUBLE);
INDArray oneThroughTwentyFour = Nd4j.arange(24).reshape('f', 4, 3, 2).castTo(DataType.DOUBLE);
INDArray result = Nd4j.tensorMmul(oneThroughSixty, oneThroughTwentyFour, new int[][] {{1, 0}, {0, 1}});
assertArrayEquals(new long[] {5, 2}, result.shape());
INDArray assertion = Nd4j.create(new double[][] {{440., 1232.}, {1232., 3752.}, {2024., 6272.}, {2816., 8792.},
{3608., 11312.}});
{3608., 11312.}});
assertEquals(assertion, result);
}
@Test
public void testNegativeShape() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNegativeShape(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
INDArray reshaped = linspace.reshape(-1, 2);
assertArrayEquals(new long[] {2, 2}, reshaped.shape());
@ -1055,7 +1178,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testGetColumnGetRow() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumnGetRow(Nd4jBackend backend) {
INDArray row = Nd4j.ones(1, 5);
for (int i = 0; i < 5; i++) {
INDArray col = row.getColumn(i);
@ -1070,7 +1195,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testDupAndDupWithOrder() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDupAndDupWithOrder(Nd4jBackend backend) {
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
int count = 0;
for (Pair<INDArray, String> pair : testInputs) {
@ -1092,7 +1219,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
}
@Test
public void testToOffsetZeroCopy() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToOffsetZeroCopy(Nd4jBackend backend) {
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
int cnt = 0;

View File

@ -23,8 +23,9 @@ package org.nd4j.linalg;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -42,18 +43,14 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class Nd4jTestsComparisonC extends BaseNd4jTest {
public class Nd4jTestsComparisonC extends BaseNd4jTestWithBackends {
private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonC.class);
public static final int SEED = 123;
DataType initialType;
DataType initialType = Nd4j.dataType();
public Nd4jTestsComparisonC(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@BeforeEach
@ -73,7 +70,9 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest {
@Test
public void testGemmWithOpsCommonsMath() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
@ -140,13 +139,13 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest {
private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair<INDArray, String> first,
Pair<INDArray, String> second) {
Pair<INDArray, String> second) {
return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")";
}
private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha,
double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
return i + "," + j + " - gemm(tA=" + transposeA + ",tB=" + transposeB + ",alpha=" + alpha + ",beta=" + beta
+ "). A=" + first.getSecond() + ", B=" + second.getSecond();
+ "). A=" + first.getSecond() + ", B=" + second.getSecond();
}
}

View File
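One side effect of losing the backend constructor, visible in Nd4jTestsComparisonC above and again in the files below, is that fields the JUnit 4 version assigned in that constructor now have to become inline initializers. A minimal sketch of the pattern (class name illustrative only):

import org.junit.jupiter.api.BeforeEach;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class ExampleComparisonTest extends BaseNd4jTestWithBackends {

    // Captured inline because JUnit 5 instantiates the class without a
    // backend argument; previously this was assigned in the constructor.
    DataType initialType = Nd4j.dataType();

    @Override
    public char ordering() {
        return 'c';
    }

    @BeforeEach
    public void before() {
        Nd4j.setDataType(DataType.DOUBLE);
    }
}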

@ -25,8 +25,9 @@ import org.apache.commons.math3.linear.RealMatrix;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -43,18 +44,14 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
@RunWith(Parameterized.class)
public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonFortran.class);
public static final int SEED = 123;
DataType initialType;
DataType initialType = Nd4j.dataType();
public Nd4jTestsComparisonFortran(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@BeforeEach
@ -75,7 +72,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
}
@Test
public void testCrash() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCrash(Nd4jBackend backend) {
INDArray array3d = Nd4j.ones(1, 10, 10);
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 0);
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 1);
@ -85,7 +84,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
}
@Test
public void testMmulWithOpsCommonsMath() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMmulWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
@ -100,7 +101,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
}
@Test
public void testGemmWithOpsCommonsMath() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
@ -156,7 +159,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
}
@Test
public void testGemvApacheCommons() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemvApacheCommons(Nd4jBackend backend) {
int[] rowsArr = new int[] {4, 4, 4, 8, 8, 8};
int[] colsArr = new int[] {2, 1, 10, 2, 1, 10};
@ -197,7 +202,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
assertArrayEquals(new long[] {rows, 1}, gemv.shape());
assertArrayEquals(new int[] {rows, 1},
new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()});
new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()});
//Check entries:
for (int r = 0; r < rows; r++) {
@ -211,7 +216,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
}
@Test
public void testAddSubtractWithOpsCommonsMath() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddSubtractWithOpsCommonsMath(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
for (int i = 0; i < first.size(); i++) {
@ -229,7 +236,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
}
@Test
public void testMulDivOnCheckUtilMatrices() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMulDivOnCheckUtilMatrices(Nd4jBackend backend) {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
for (int i = 0; i < first.size(); i++) {
@ -245,13 +254,13 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
}
private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair<INDArray, String> first,
Pair<INDArray, String> second) {
Pair<INDArray, String> second) {
return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")";
}
private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha,
double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
return i + "," + j + " - gemm(tA=" + transposeA + ",tB= " + transposeB + ",alpha=" + alpha + ",beta= " + beta
+ "). A=" + first.getSecond() + ", B=" + second.getSecond();
+ "). A=" + first.getSecond() + ", B=" + second.getSecond();
}
}

View File

@ -23,8 +23,9 @@ package org.nd4j.linalg;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -36,18 +37,15 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j
@RunWith(Parameterized.class)
public class Nd4jTestsF extends BaseNd4jTest {
DataType initialType;
public class Nd4jTestsF extends BaseNd4jTestWithBackends {
public Nd4jTestsF(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
DataType initialType = Nd4j.dataType();
@Test
public void testConcat3D_Vstack_F() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcat3D_Vstack_F(Nd4jBackend backend) {
//Nd4j.getExecutioner().enableVerboseMode(true);
//Nd4j.getExecutioner().enableDebugMode(true);
@ -79,7 +77,9 @@ public class Nd4jTestsF extends BaseNd4jTest {
@Test
public void testSlice_1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlice_1(Nd4jBackend backend) {
val arr = Nd4j.linspace(1,4, 4, DataType.DOUBLE).reshape(2, 2, 1);
val exp0 = Nd4j.create(new double[]{1, 3}, new int[] {2, 1});
val exp1 = Nd4j.create(new double[]{2, 4}, new int[] {2, 1});

View File

@ -22,8 +22,9 @@ package org.nd4j.linalg;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -34,15 +35,13 @@ import java.util.*;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.*;
@RunWith(Parameterized.class)
public class ShufflesTests extends BaseNd4jTest {
public ShufflesTests(Nd4jBackend backend) {
super(backend);
}
public class ShufflesTests extends BaseNd4jTestWithBackends {
@Test
public void testSimpleShuffle1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimpleShuffle1(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(10, 10);
for (int x = 0; x < 10; x++) {
array.getRow(x).assign(x);
@ -64,7 +63,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testSimpleShuffle2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimpleShuffle2(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(10, 10);
for (int x = 0; x < 10; x++) {
array.getColumn(x).assign(x);
@ -79,7 +80,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testSimpleShuffle3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSimpleShuffle3(Nd4jBackend backend) {
INDArray array = Nd4j.zeros(11, 10);
for (int x = 0; x < 11; x++) {
array.getRow(x).assign(x);
@ -95,7 +98,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testSymmetricShuffle1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSymmetricShuffle1(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10);
INDArray labels = Nd4j.zeros(10, 3);
for (int x = 0; x < 10; x++) {
@ -133,7 +138,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testSymmetricShuffle2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSymmetricShuffle2(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10, 20);
INDArray labels = Nd4j.zeros(10, 10, 3);
@ -171,7 +178,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testSymmetricShuffle3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSymmetricShuffle3(Nd4jBackend backend) {
INDArray features = Nd4j.zeros(10, 10, 20);
INDArray featuresMask = Nd4j.zeros(10, 20);
INDArray labels = Nd4j.zeros(10, 10, 3);
@ -236,7 +245,9 @@ public class ShufflesTests extends BaseNd4jTest {
* @throws Exception
*/
@Test
public void testHalfVectors1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testHalfVectors1(Nd4jBackend backend) {
int[] array1 = ArrayUtil.buildHalfVector(new Random(12), 20);
int[] array2 = ArrayUtil.buildHalfVector(new Random(75), 20);
@ -257,7 +268,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testInterleavedVector1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInterleavedVector1(Nd4jBackend backend) {
int[] array1 = ArrayUtil.buildInterleavedVector(new Random(12), 20);
int[] array2 = ArrayUtil.buildInterleavedVector(new Random(75), 20);
@ -278,7 +291,9 @@ public class ShufflesTests extends BaseNd4jTest {
}
@Test
public void testInterleavedVector3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInterleavedVector3(Nd4jBackend backend) {
for (int e = 0; e < 1000; e++) {
int length = e + 256; //RandomUtils.nextInt(121, 2073);
int[] array1 = ArrayUtil.buildInterleavedVector(new Random(System.currentTimeMillis()), length);

View File

@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.eigen.Eigen;
@ -35,16 +36,11 @@ import org.nd4j.common.util.ArrayUtil;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
@Slf4j
public class TestEigen extends BaseNd4jTest {
public class TestEigen extends BaseNd4jTestWithBackends {
protected DataType initialType;
public TestEigen(Nd4jBackend backend) {
super(backend);
initialType = Nd4j.dataType();
}
protected DataType initialType = Nd4j.dataType();
@BeforeEach
public void before() {
@ -59,7 +55,9 @@ public class TestEigen extends BaseNd4jTest {
// test of functions added by Luke Czapla
// Compares solution of A x = L x to solution to A x = L B x when it is simple
@Test
public void test2Syev() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void test2Syev(Nd4jBackend backend) {
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
Nd4j.setDefaultDataTypes(dt, dt);
@ -78,7 +76,9 @@ public class TestEigen extends BaseNd4jTest {
}
@Test
public void testSyev() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSyev(Nd4jBackend backend) {
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
//log.info("Datatype: {}", dt);
Nd4j.setDefaultDataTypes(dt, dt);

View File

@ -24,23 +24,23 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.util.ArrayUtil;
@RunWith(Parameterized.class)
@Slf4j
public class ToStringTest extends BaseNd4jTest {
public ToStringTest(Nd4jBackend backend) {
super(backend);
}
public class ToStringTest extends BaseNd4jTestWithBackends {
@Test
public void testToString() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToString(Nd4jBackend backend) throws Exception {
assertEquals("[ 1, 2, 3]",
Nd4j.createFromArray(1, 2, 3).toString());
@ -58,6 +58,8 @@ public class ToStringTest extends BaseNd4jTest {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToStringScalars(){
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.activations;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.activations.impl.ActivationCube;
import org.nd4j.linalg.activations.impl.ActivationELU;
import org.nd4j.linalg.activations.impl.ActivationGELU;
@ -55,12 +56,9 @@ import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class TestActivation extends BaseNd4jTest {
public TestActivation(Nd4jBackend backend) {
super(backend);
}
public class TestActivation extends BaseNd4jTestWithBackends {
@Override
public char ordering() {
@ -79,7 +77,9 @@ public class TestActivation extends BaseNd4jTest {
}
@Test
public void testRelu(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRelu(Nd4jBackend backend){
Double[] max = {null, 6.0, 2.5, 5.0};
Double[] threshold = {0.0, 0.0, 0.75, 0.2};
@ -131,30 +131,32 @@ public class TestActivation extends BaseNd4jTest {
}
@Test
public void testJson() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testJson(Nd4jBackend backend) throws Exception {
IActivation[] activations = new IActivation[] {new ActivationCube(), new ActivationELU(0.25),
new ActivationHardSigmoid(), new ActivationHardTanH(), new ActivationIdentity(),
new ActivationLReLU(0.25), new ActivationRationalTanh(), new ActivationReLU(),
new ActivationRReLU(0.25, 0.5), new ActivationSigmoid(), new ActivationSoftmax(),
new ActivationSoftPlus(), new ActivationSoftSign(), new ActivationTanH(), new ActivationGELU(), new ActivationGELU(true)};
new ActivationHardSigmoid(), new ActivationHardTanH(), new ActivationIdentity(),
new ActivationLReLU(0.25), new ActivationRationalTanh(), new ActivationReLU(),
new ActivationRReLU(0.25, 0.5), new ActivationSigmoid(), new ActivationSoftmax(),
new ActivationSoftPlus(), new ActivationSoftSign(), new ActivationTanH(), new ActivationGELU(), new ActivationGELU(true)};
String[][] expectedFields = new String[][] {{"@class"}, //Cube
{"@class", "alpha"}, //ELU
{"@class"}, //Hard sigmoid
{"@class"}, //Hard TanH
{"@class"}, //Identity
{"@class", "alpha"}, //Leaky Relu
{"@class"}, //rational tanh
{"@class", "max", "negativeSlope", "threshold"}, //relu
{"@class", "l", "u"}, //rrelu
{"@class"}, //sigmoid
{"@class"}, //Softmax
{"@class"}, //Softplus
{"@class"}, //Softsign
{"@class"}, //Tanh
{"@class", "precise"}, //GELU
{"@class", "precise"} //GELU precise
{"@class", "alpha"}, //ELU
{"@class"}, //Hard sigmoid
{"@class"}, //Hard TanH
{"@class"}, //Identity
{"@class", "alpha"}, //Leaky Relu
{"@class"}, //rational tanh
{"@class", "max", "negativeSlope", "threshold"}, //relu
{"@class", "l", "u"}, //rrelu
{"@class"}, //sigmoid
{"@class"}, //Softmax
{"@class"}, //Softplus
{"@class"}, //Softsign
{"@class"}, //Tanh
{"@class", "precise"}, //GELU
{"@class", "precise"} //GELU precise
};
@ -172,7 +174,7 @@ public class TestActivation extends BaseNd4jTest {
String[] expFields = expectedFields[i];
String msg = activations[i].toString() + "\tExpected fields: " + Arrays.toString(expFields)
+ "\tActual fields: " + actualFieldsByName;
+ "\tActual fields: " + actualFieldsByName;
assertEquals(expFields.length, actualFieldsByName.size(),msg);
for (String s : expFields) {

View File

@ -20,21 +20,20 @@
package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.factory.Environment;
import org.nd4j.linalg.factory.Nd4j;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertFalse;
public class TestBackend extends BaseNd4jTest {
public class TestBackend extends BaseNd4jTestWithBackends {
public TestBackend(Nd4jBackend backend) {
super(backend);
}
@Test
public void TestBuildInfo(){
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBuildInfo(Nd4jBackend backend){
System.out.println("Backend build info: " + backend.buildInfo());
}
}

View File

@ -20,26 +20,27 @@
package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Environment;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertFalse;
public class TestEnvironment extends BaseNd4jTest {
public class TestEnvironment extends BaseNd4jTestWithBackends {
public TestEnvironment(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
return 'c';
}
@Test
public void testEnvironment(){
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEnvironment(Nd4jBackend backend){
Environment e = Nd4j.getEnvironment();
System.out.println("BLAS version: " + e.blasMajorVersion() + "." + e.blasMinorVersion() + "." + e.blasPatchVersion());
System.out.println("CPU: " + e.isCPU());

View File

@ -26,7 +26,9 @@ import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -40,16 +42,12 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class TestNDArrayCreation extends BaseNd4jTest {
public TestNDArrayCreation(Nd4jBackend backend) {
super(backend);
}
public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
@Test
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public void testBufferCreation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBufferCreation(Nd4jBackend backend) {
DataBuffer dataBuffer = Nd4j.createBuffer(new float[] {1, 2});
Pointer pointer = dataBuffer.pointer();
FloatPointer floatPointer = new FloatPointer(pointer);
@ -69,6 +67,8 @@ public class TestNDArrayCreation extends BaseNd4jTest {
@Test
@Disabled
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCreateNpy() throws Exception {
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile());
assertEquals(2, arrCreate.size(0));
@ -82,7 +82,9 @@ public class TestNDArrayCreation extends BaseNd4jTest {
@Test
@Disabled
public void testCreateNpz() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCreateNpz(Nd4jBackend backend) throws Exception {
Map<String, INDArray> map = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile());
assertEquals(true, map.containsKey("x"));
assertEquals(true, map.containsKey("y"));
@ -100,8 +102,7 @@ public class TestNDArrayCreation extends BaseNd4jTest {
}
@Test
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public void testCreateNpy3() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCreateNpy3(Nd4jBackend backend) throws Exception {
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
assertEquals(8, arrCreate.length());
assertEquals(3, arrCreate.rank());
@ -113,7 +114,7 @@ public class TestNDArrayCreation extends BaseNd4jTest {
@Test
@Disabled // this is endless test
public void testEndlessAllocation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEndlessAllocation(Nd4jBackend backend) {
Nd4j.getEnvironment().setMaxSpecialMemory(1);
while (true) {
val arr = Nd4j.createUninitialized(DataType.FLOAT, 100000000);

View File

@ -21,24 +21,23 @@
package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
public class TestNDArrayCreationUtil extends BaseNd4jTest {
public class TestNDArrayCreationUtil extends BaseNd4jTestWithBackends {
public TestNDArrayCreationUtil(Nd4jBackend backend) {
super(backend);
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testShapes() {
long[] shape2d = {2, 3};

View File

@ -21,20 +21,21 @@
package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
public class TestNamespaces extends BaseNd4jTest {
public class TestNamespaces extends BaseNd4jTestWithBackends {
public TestNamespaces(Nd4jBackend backend) {
super(backend);
}
@Test
public void testBitwiseSimple(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBitwiseSimple(Nd4jBackend backend){
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
INDArray y = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
@ -50,7 +51,9 @@ public class TestNamespaces extends BaseNd4jTest {
}
@Test
public void testMathSimple(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testMathSimple(Nd4jBackend backend) {
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1);
INDArray abs = Nd4j.math.abs(x);
// System.out.println(x);
@ -65,7 +68,9 @@ public class TestNamespaces extends BaseNd4jTest {
}
@Test
public void testRandomSimple(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomSimple(Nd4jBackend backend){
INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10);
// System.out.println(normal);
INDArray uniform = Nd4j.random.uniform(0, 1, DataType.FLOAT, 10);
@ -73,7 +78,9 @@ public class TestNamespaces extends BaseNd4jTest {
}
@Test
public void testNeuralNetworkSimple(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNeuralNetworkSimple(Nd4jBackend backend){
INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10));
// System.out.println(out);
INDArray out2 = Nd4j.nn.softmax(Nd4j.random.normal(0, 1, DataType.FLOAT, 4, 5), 1);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.blas;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -31,15 +32,14 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class LapackTest extends BaseNd4jTest {
public LapackTest(Nd4jBackend backend) {
super(backend);
}
public class LapackTest extends BaseNd4jTestWithBackends {
@Test
public void testQRSquare() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testQRSquare(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9});
A = A.reshape('c', 3, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape());
@ -57,7 +57,9 @@ public class LapackTest extends BaseNd4jTest {
}
@Test
public void testQRRect() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testQRRect(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
A = A.reshape('f', 4, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape());
@ -75,7 +77,9 @@ public class LapackTest extends BaseNd4jTest {
}
@Test
public void testCholeskyL() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCholeskyL(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {2, -1, 1, -1, 2, -1, 1, -1, 2,});
A = A.reshape('c', 3, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape());
@ -92,7 +96,9 @@ public class LapackTest extends BaseNd4jTest {
}
@Test
public void testCholeskyU() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCholeskyU(Nd4jBackend backend) {
INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,});
A = A.reshape('f', 3, 3);
INDArray O = Nd4j.create(A.dataType(), A.shape());

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.api.blas;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -35,14 +36,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class Level1Test extends BaseNd4jTest {
public Level1Test(Nd4jBackend backend) {
super(backend);
}
public class Level1Test extends BaseNd4jTestWithBackends {
@Test
public void testDot() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDot(Nd4jBackend backend) {
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4});
assertEquals(30, Nd4j.getBlasWrapper().dot(vec1, vec2), 1e-1);
@ -55,7 +55,9 @@ public class Level1Test extends BaseNd4jTest {
}
@Test
public void testAxpy() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAxpy(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray row = matrix.getRow(1);
Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row);
@ -64,7 +66,9 @@ public class Level1Test extends BaseNd4jTest {
}
@Test
public void testAxpy2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAxpy2(Nd4jBackend backend) {
val rowX = Nd4j.create(new double[]{1, 2, 3, 4});
val rowY = Nd4j.create(new double[]{1, 2, 3, 4});
val exp = Nd4j.create(new double[]{3, 6, 9, 12});

View File

@ -21,23 +21,23 @@
package org.nd4j.linalg.api.blas;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class Level2Test extends BaseNd4jTest {
public Level2Test(Nd4jBackend backend) {
super(backend);
}
public class Level2Test extends BaseNd4jTestWithBackends {
@Test
public void testGemv1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv1(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -51,7 +51,9 @@ public class Level2Test extends BaseNd4jTest {
}
@Test
public void testGemv2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv2(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -65,7 +67,9 @@ public class Level2Test extends BaseNd4jTest {
}
@Test
public void testGemv3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv3(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -79,7 +83,9 @@ public class Level2Test extends BaseNd4jTest {
}
@Test
public void testGemv4() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv4(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -93,7 +99,9 @@ public class Level2Test extends BaseNd4jTest {
}
@Test
public void testGemv5() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv5(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -109,7 +117,9 @@ public class Level2Test extends BaseNd4jTest {
}
@Test
public void testGemv6() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv6(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -125,7 +135,9 @@ public class Level2Test extends BaseNd4jTest {
}
@Test
public void testGemv7() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemv7(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);

View File

@ -21,23 +21,23 @@
package org.nd4j.linalg.api.blas;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class Level3Test extends BaseNd4jTest {
public Level3Test(Nd4jBackend backend) {
super(backend);
}
public class Level3Test extends BaseNd4jTestWithBackends {
@Test
public void testGemm1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm1(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape(1, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
@ -47,7 +47,9 @@ public class Level3Test extends BaseNd4jTest {
}
@Test
public void testGemm2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm2(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape('f', 1, 100);
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
@ -57,7 +59,9 @@ public class Level3Test extends BaseNd4jTest {
}
@Test
public void testGemm3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm3(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
@ -75,7 +79,9 @@ public class Level3Test extends BaseNd4jTest {
}
@Test
public void testGemm4() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm4(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);
@ -92,7 +98,9 @@ public class Level3Test extends BaseNd4jTest {
}
@Test
public void testGemm5() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm5(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
@ -106,7 +114,9 @@ public class Level3Test extends BaseNd4jTest {
}
@Test
public void testGemm6() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm6(Nd4jBackend backend) {
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.blas.params;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -33,16 +34,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class ParamsTestsF extends BaseNd4jTest {
public ParamsTestsF(Nd4jBackend backend) {
super(backend);
}
public class ParamsTestsF extends BaseNd4jTestWithBackends {
@Test
public void testGemm() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGemm (Nd4jBackend backend) {
INDArray a = Nd4j.create(2, 2);
INDArray b = Nd4j.create(2, 3);
INDArray c = Nd4j.create(2, 3);

View File

@ -25,9 +25,10 @@ import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
@ -45,16 +46,15 @@ import java.nio.ByteOrder;
import static org.junit.jupiter.api.Assertions.*;
@Slf4j
@RunWith(Parameterized.class)
public class DataBufferTests extends BaseNd4jTest {
public DataBufferTests(Nd4jBackend backend) {
super(backend);
}
public class DataBufferTests extends BaseNd4jTestWithBackends {
@Test
@Disabled("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657")
public void testNoArgCreateBufferFromArray() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNoArgCreateBufferFromArray(Nd4jBackend backend) {
//Tests here:
//1. Create from JVM array
@ -280,7 +280,9 @@ public class DataBufferTests extends BaseNd4jTest {
@Test
public void testCreateTypedBuffer() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testCreateTypedBuffer(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
@ -350,7 +352,9 @@ public class DataBufferTests extends BaseNd4jTest {
}
@Test
public void testAsBytes() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAsBytes(Nd4jBackend backend) {
INDArray orig = Nd4j.linspace(DataType.INT, 0, 10, 1);
for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16,
@ -404,7 +408,9 @@ public class DataBufferTests extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testEnsureLocation(){
//https://github.com/eclipse/deeplearning4j/issues/8783
Nd4j.create(1);

View File

@ -23,9 +23,10 @@ package org.nd4j.linalg.api.buffer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
@ -33,13 +34,10 @@ import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.jupiter.api.Assertions.assertThrows;
@RunWith(Parameterized.class)
public class DataTypeValidationTests extends BaseNd4jTest {
DataType initialType;
public DataTypeValidationTests(Nd4jBackend backend) {
super(backend);
}
public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
@BeforeEach
public void setUp() {
@ -48,7 +46,7 @@ public class DataTypeValidationTests extends BaseNd4jTest {
}
@AfterEach
public void shutUp() {
public void reset() {
Nd4j.setDataType(initialType);
}
@ -73,7 +71,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
* Testing level1 blas
*/
@Test()
public void testBlasValidation1() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBlasValidation1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
INDArray x = Nd4j.create(10);
@ -90,7 +90,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
* Testing level2 blas
*/
@Test()
public void testBlasValidation2() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBlasValidation2(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> {
INDArray a = Nd4j.create(100, 10);
INDArray x = Nd4j.create(100);
@ -108,7 +110,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
* Testing level3 blas
*/
@Test()
public void testBlasValidation3() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBlasValidation3(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
INDArray x = Nd4j.create(100, 100);
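The hunks above also convert expected-exception tests: where JUnit 4 allowed @Test(expected = SomeException.class), the migrated methods keep a plain @Test()/@ParameterizedTest and wrap the failing code in assertThrows. A minimal, self-contained sketch of that pattern (hypothetical example class and message; it does not touch ND4J):

import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

class ExpectedExceptionExample {

    @Test
    void throwsOnInvalidState() {
        // JUnit 4: @Test(expected = IllegalStateException.class), with no access to the exception.
        // JUnit 5: assertThrows returns the thrown exception so the test can inspect it further.
        IllegalStateException ex = assertThrows(IllegalStateException.class, () -> {
            throw new IllegalStateException("mixed data types are not allowed in this example");
        });
        assertTrue(ex.getMessage().contains("mixed data types"));
    }
}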

View File

@ -26,9 +26,10 @@ import org.bytedeco.javacpp.indexer.Indexer;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
@ -54,34 +55,31 @@ import static org.junit.jupiter.api.Assertions.*;
*
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public class DoubleDataBufferTest extends BaseNd4jTest {
public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
DataType initialType;
public DoubleDataBufferTest(Nd4jBackend backend) {
super(backend);
initialType = Nd4j.dataType();
}
DataType initialType = Nd4j.dataType();
@BeforeEach
public void before() {
public void before(Nd4jBackend backend) {
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
}
@AfterEach
public void after() {
public void after(Nd4jBackend backend) {
DataTypeUtil.setDTypeForContext(initialType);
}
@Test
public void testPointerCreation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPointerCreation(Nd4jBackend backend) {
DoublePointer floatPointer = new DoublePointer(1, 2, 3, 4);
Indexer indexer = DoubleIndexer.create(floatPointer);
DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.DOUBLE, 4, indexer);
@ -89,8 +87,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
assertArrayEquals(other.asDouble(), buffer.asDouble(), 0.001);
}
@Test
public void testGetSet() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetSet(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
double[] d2 = d.asDouble();
@ -100,10 +100,12 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerialization2() throws Exception {
INDArray[] arr = new INDArray[] {Nd4j.ones(1, 10),
// Nd4j.ones(5,10).getRow(2)
// Nd4j.ones(5,10).getRow(2)
};
for (INDArray a : arr) {
@ -128,7 +130,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerialization(@TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
DataBuffer buf = Nd4j.createBuffer(5);
@ -150,8 +154,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
public void testDup() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDup(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
DataBuffer d2 = d.dup();
@ -160,8 +166,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
@Test
public void testPut() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPut(Nd4jBackend backend) {
double[] d1 = new double[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
d.put(0, 0.0);
@ -171,8 +179,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
public void testGetRange() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
double[] get = buffer.getDoublesAt(0, 3);
double[] data = new double[] {1, 2, 3};
@ -186,8 +196,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
public void testGetOffsetRange() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetOffsetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
double[] get = buffer.getDoublesAt(1, 3);
double[] data = new double[] {2, 3, 4};
@ -201,8 +213,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
public void testAssign() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssign(Nd4jBackend backend) {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
DataBuffer one = Nd4j.createBuffer(new double[] {1});
DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3});
@ -212,8 +226,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
public void testOffset() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOffset(Nd4jBackend backend) {
DataBuffer create = Nd4j.createBuffer(new double[] {1, 2, 3, 4}, 2);
assertEquals(2, create.length());
assertEquals(0, create.offset());
@ -222,8 +238,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
public void testReallocation() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity());
double[] old = buffer.asDouble();
@ -232,10 +250,12 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
assertArrayEquals(old, Arrays.copyOf(buffer.asDouble(), 4), 1e-1);
}
@Test
public void testReallocationWorkspace() {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4});
@ -249,7 +269,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddressPointer(){
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
return;
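The DoubleDataBufferTest changes above move the capture of the global data type from the deleted JUnit 4 constructor into a field initializer, with @BeforeEach and @AfterEach switching it for the test and restoring it afterwards. A rough, self-contained sketch of that save-and-restore lifecycle, using a system property as a stand-in for Nd4j's global data type (all names here are hypothetical):

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;

class GlobalStateLifecycleExample {

    // Captured once per test instance, replacing the initialisation done in the old constructor.
    private final String initialValue = System.getProperty("example.dtype", "float");

    @BeforeEach
    void overrideForTest() {
        System.setProperty("example.dtype", "double");
    }

    @AfterEach
    void restore() {
        // Put the global setting back so one invocation cannot leak state into the next.
        System.setProperty("example.dtype", initialValue);
    }

    @Test
    void seesOverriddenValue() {
        assertEquals("double", System.getProperty("example.dtype"));
    }
}

Because Jupiter creates a new test instance per test method by default, the field initializer runs before every invocation, which mirrors the per-test capture in the migrated tests.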

View File

@ -27,7 +27,9 @@ import org.bytedeco.javacpp.indexer.Indexer;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
@ -54,14 +56,9 @@ import static org.junit.jupiter.api.Assertions.*;
* @author Adam Gibson
*/
@Disabled("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public class FloatDataBufferTest extends BaseNd4jTest {
public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
DataType initialType;
public FloatDataBufferTest(Nd4jBackend backend) {
super(backend);
initialType = Nd4j.dataType();
}
DataType initialType = Nd4j.dataType();
@BeforeEach
public void before() {
@ -76,7 +73,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test
public void testPointerCreation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPointerCreation(Nd4jBackend backend) {
FloatPointer floatPointer = new FloatPointer(1, 2, 3, 4);
Indexer indexer = FloatIndexer.create(floatPointer);
DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.FLOAT, 4, indexer);
@ -85,7 +84,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testGetSet() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetSet(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
float[] d2 = d.asFloat();
@ -96,7 +97,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test
public void testSerialization(@TempDir Path tempDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerialization(@TempDir Path tempDir,Nd4jBackend backend) throws Exception {
File dir = tempDir.toFile();
DataBuffer buf = Nd4j.createBuffer(5);
String fileName = "buf.ser";
@ -117,7 +120,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testDup() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testDup(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
DataBuffer d2 = d.dup();
@ -125,7 +130,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testToNio() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testToNio(Nd4jBackend backend) {
DataBuffer buff = Nd4j.createTypedBuffer(new double[] {1, 2, 3, 4}, DataType.FLOAT);
assertEquals(4, buff.length());
if (buff.allocationMode() == DataBuffer.AllocationMode.HEAP)
@ -137,7 +144,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testPut() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPut(Nd4jBackend backend) {
float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
d.put(0, 0.0);
@ -148,7 +157,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test
public void testGetRange() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(0, 3);
float[] data = new float[] {1, 2, 3};
@ -164,7 +175,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test
public void testGetOffsetRange() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetOffsetRange(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(1, 3);
float[] data = new float[] {2, 3, 4};
@ -181,7 +194,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
@Test
public void testAsBytes() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAsBytes(Nd4jBackend backend) {
INDArray arr = Nd4j.create(5);
byte[] d = arr.data().asBytes();
assertEquals(4 * 5, d.length,getFailureMessage());
@ -191,7 +206,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testAssign() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAssign(Nd4jBackend backend) {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
DataBuffer one = Nd4j.createBuffer(new double[] {1});
DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3});
@ -201,7 +218,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testReadWrite() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReadWrite(Nd4jBackend backend) throws Exception {
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
ByteArrayOutputStream bos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(bos);
@ -215,7 +234,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testOffset() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOffset(Nd4jBackend backend) {
DataBuffer create = Nd4j.createBuffer(new float[] {1, 2, 3, 4}, 2);
assertEquals(2, create.length());
assertEquals(0, create.offset());
@ -225,7 +246,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testReallocation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity());
float[] old = buffer.asFloat();
@ -236,7 +259,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testReallocationWorkspace() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
@ -253,7 +278,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
}
@Test
public void testAddressPointer(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAddressPointer(Nd4jBackend backend){
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
return;
}
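Several serialization tests above now receive a @TempDir Path alongside the backend argument, taking over the role of the JUnit 4 TemporaryFolder rule. The sketch below (hypothetical names; a string stands in for the backend) shows one workable signature, with the argument-source parameter declared first and the resolver-injected @TempDir after it:

import java.io.File;
import java.nio.file.Path;
import java.util.stream.Stream;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import static org.junit.jupiter.api.Assertions.assertTrue;

class TempDirParamExample {

    static Stream<Arguments> configs() {
        return Stream.of("cpu").map(Arguments::of);
    }

    @ParameterizedTest
    @MethodSource("configs")
    void writesIntoManagedTempDir(String backend, @TempDir Path testDir) throws Exception {
        // @TempDir hands the test a directory that Jupiter creates and deletes per invocation,
        // replacing the JUnit 4 TemporaryFolder rule removed by this migration.
        File out = testDir.resolve(backend + ".bin").toFile();
        assertTrue(out.createNewFile());
    }
}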

View File

@ -23,7 +23,9 @@ package org.nd4j.linalg.api.buffer;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
@ -37,13 +39,12 @@ import java.util.Arrays;
import static org.junit.jupiter.api.Assertions.*;
public class IntDataBufferTests extends BaseNd4jTest {
public class IntDataBufferTests extends BaseNd4jTestWithBackends {
public IntDataBufferTests(Nd4jBackend backend) {
super(backend);
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testBasicSerde1() throws Exception {
@ -82,7 +83,9 @@ public class IntDataBufferTests extends BaseNd4jTest {
*/
@Test
public void testReallocation() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocation(Nd4jBackend backend) {
DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity());
buffer.reallocate(6);
@ -94,9 +97,11 @@ public class IntDataBufferTests extends BaseNd4jTest {
}
@Test
public void testReallocationWorkspace() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testReallocationWorkspace(Nd4jBackend backend) {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -37,17 +38,15 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class IndexingTests extends BaseNd4jTest {
public class IndexingTests extends BaseNd4jTestWithBackends {
public IndexingTests(Nd4jBackend backend) {
super(backend);
}
@Test
public void testINDArrayIndexingEqualToRank() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testINDArrayIndexingEqualToRank(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{
{0,1,2},
@ -62,7 +61,9 @@ public class IndexingTests extends BaseNd4jTest {
@Test
public void testINDArrayIndexingLessThanRankSimple() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testINDArrayIndexingLessThanRankSimple(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{
{0},
@ -76,7 +77,9 @@ public class IndexingTests extends BaseNd4jTest {
@Test
public void testINDArrayIndexingLessThanRankFourDimension() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testINDArrayIndexingLessThanRankFourDimension(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE);
INDArray indexes = Nd4j.create(new double[][]{
{0},{1}
@ -89,7 +92,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testPutSimple() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutSimple(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2);
INDArray indexes = Nd4j.create(new double[][]{
{0},{1}
@ -101,7 +106,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testGetScalar() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetScalar(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
INDArray d = arr.get(NDArrayIndex.point(1));
assertTrue(d.isScalar());
@ -110,14 +117,18 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testNewAxis() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis(Nd4jBackend backend) {
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
// System.out.println(view);
}
@Test
public void testVectorIndexing() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorIndexing(Nd4jBackend backend) {
INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE);
int[] index = new int[] {5, 8, 9};
INDArray columnsTest = x.getColumns(index);
@ -129,7 +140,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testGetRowsColumnsMatrix() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRowsColumnsMatrix(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 6);
INDArray firstAndSecondColumnsAssertion = Nd4j.create(new double[][] {{1, 5}, {2, 6}, {3, 7}, {4, 8}});
@ -147,7 +160,9 @@ public class IndexingTests extends BaseNd4jTest {
@Test
public void testSlicing() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSlicing(Nd4jBackend backend) {
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
INDArray slice1Assert = Nd4j.create(new double[] {2, 6, 10, 14});
INDArray slice1Test = arange.slice(1);
@ -155,7 +170,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testArangeMul() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testArangeMul(Nd4jBackend backend) {
INDArray arange = Nd4j.arange(1, 17).reshape('f', 4, 4).castTo(DataType.DOUBLE);
INDArrayIndex index = NDArrayIndex.interval(0, 2);
INDArray get = arange.get(index, index);
@ -167,7 +184,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testGetIndicesVector() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndicesVector(Nd4jBackend backend) {
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
INDArray test = Nd4j.create(new double[] {2, 3});
INDArray result = line.get(NDArrayIndex.point(0), NDArrayIndex.interval(1, 3));
@ -175,7 +194,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void testGetIndicesVectorView() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndicesVectorView(Nd4jBackend backend) {
INDArray matrix = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape('c',5, 5);
INDArray column = matrix.getColumn(0).reshape(1,5);
INDArray test = Nd4j.create(new double[] {6, 11});
@ -193,7 +214,9 @@ public class IndexingTests extends BaseNd4jTest {
}
@Test
public void test2dGetPoint(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void test2dGetPoint(Nd4jBackend backend){
INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4);
for( int i=0; i<3; i++ ){
INDArray exp = Nd4j.create(new double[]{i*4+1, i*4+2, i*4+3, i*4+4});
@ -206,7 +229,7 @@ public class IndexingTests extends BaseNd4jTest {
assertEquals(exp, get);
}
for( int i=0; i<4; i++ ){
for( int i = 0; i < 4; i++) {
INDArray exp = Nd4j.create(new double[]{1+i, 5+i, 9+i});
INDArray col = arr.getColumn(i);
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(i));

View File

@ -21,10 +21,11 @@
package org.nd4j.linalg.api.indexing;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -49,16 +50,15 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class IndexingTestsC extends BaseNd4jTest {
public class IndexingTestsC extends BaseNd4jTestWithBackends {
public IndexingTestsC(Nd4jBackend backend) {
super(backend);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNegativeBounds() {
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
@ -70,7 +70,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion,get);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis() {
INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2);
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all());
@ -79,7 +81,9 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void broadcastBug() {
INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2});
final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0));
@ -90,7 +94,9 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIntervalsIn3D() {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
@ -99,7 +105,9 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSmallInterval() {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
@ -108,7 +116,9 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllWithNewAxisAndInterval() {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3);
@ -117,7 +127,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion2, get2);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllWithNewAxisInMiddle() {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3);
@ -126,7 +138,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion2, get2);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testAllWithNewAxis() {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray get = arr.get(newAxis(), all(), point(1));
@ -136,7 +150,9 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexingWithMmul() {
INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
@ -147,7 +163,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion, c);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPointPointInterval() {
INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3);
INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3));
@ -156,7 +174,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion, get);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIntervalLowerBound() {
INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
@ -167,7 +187,9 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetPointRowVector() {
INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1);
@ -177,7 +199,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(Nd4j.linspace(1, 100, 100, DataType.DOUBLE), arr2);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSpecifiedIndexVector() {
INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2);
@ -194,7 +218,9 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testPutRowIndexing() {
INDArray arr = Nd4j.ones(1, 10);
INDArray row = Nd4j.create(1, 10);
@ -204,7 +230,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(arr, row);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorIndexing2() {
INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true));
INDArray assertion = Nd4j.create(new double[] {2, 4});
@ -219,7 +247,9 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testOffsetsC() {
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
assertEquals(3, NDArrayIndex.offset(arr, 1, 1));
@ -235,7 +265,9 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexFor() {
long[] shape = {1, 2};
INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape);
@ -244,7 +276,9 @@ public class IndexingTestsC extends BaseNd4jTest {
}
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetScalar() {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
INDArray d = arr.get(point(1));
@ -253,7 +287,9 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testVectorIndexing() {
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1);
INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
@ -261,14 +297,18 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion, viewTest);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNegativeIndices() {
INDArray test = Nd4j.create(10, 10, 10);
test.putScalar(new int[] {0, 0, -1}, 1.0);
assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber());
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndices2d() {
INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2);
INDArray firstRow = twoByTwo.getRow(0);
@ -286,7 +326,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(Nd4j.create(new double[] {4}, new int[]{1,1}), individualElement);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRow() {
Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5);
@ -303,7 +345,9 @@ public class IndexingTestsC extends BaseNd4jTest {
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetRowEdgeCase() {
INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
INDArray get = rowVec.getRow(0); //Returning shape [1,1]
@ -312,7 +356,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(rowVec, get);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetColumnEdgeCase() {
INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose();
INDArray get = colVec.getColumn(0); //Returning shape [1,1]
@ -321,7 +367,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(colVec, get);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testConcatColumns() {
INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE);
INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE);
@ -330,7 +378,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion, concat);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testGetIndicesVector() {
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
INDArray test = Nd4j.create(new double[] {2, 3});
@ -338,7 +388,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(test, result);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testArangeMul() {
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
INDArrayIndex index = interval(0, 2);
@ -349,7 +401,9 @@ public class IndexingTestsC extends BaseNd4jTest {
assertEquals(assertion, mul);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIndexingThorough(){
long[] fullShape = {3,4,5,6,7};
@ -549,7 +603,9 @@ public class IndexingTestsC extends BaseNd4jTest {
return d;
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void debugging(){
long[] inShape = {3,4};
INDArrayIndex[] indexes = new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(1, 2, 4)};

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing.resolve;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -36,15 +37,14 @@ import static org.junit.jupiter.api.Assertions.*;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class NDArrayIndexResolveTests extends BaseNd4jTest {
public NDArrayIndexResolveTests(Nd4jBackend backend) {
super(backend);
}
public class NDArrayIndexResolveTests extends BaseNd4jTestWithBackends {
@Test
public void testResolvePoint() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testResolvePoint(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2);
INDArrayIndex[] test = NDArrayIndex.resolve(arr.shape(), NDArrayIndex.point(1));
INDArrayIndex[] assertion = {NDArrayIndex.point(1), NDArrayIndex.all()};
@ -59,6 +59,8 @@ public class NDArrayIndexResolveTests extends BaseNd4jTest {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testResolvePointVector() {
INDArray arr = Nd4j.linspace(1, 4, 4);
INDArrayIndex[] getPoint = {NDArrayIndex.point(1)};

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing.shape;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.Indices;
@ -34,19 +35,15 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class IndexShapeTests extends BaseNd4jTest {
public IndexShapeTests(Nd4jBackend backend) {
super(backend);
}
public class IndexShapeTests extends BaseNd4jTestWithBackends {
private int[] shape = {1, 1, 2, 1, 3, 4, 5, 1};
@Test
public void testSinglePoint() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSinglePoint(Nd4jBackend backend) {
/*
Assumes all indexes are filled out.
Test simple general point case
@ -77,7 +74,9 @@ public class IndexShapeTests extends BaseNd4jTest {
}
@Test
public void testInterval() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testInterval(Nd4jBackend backend) {
int[] basicAssertion = {1, 1, 1, 1, 3, 1, 2, 1};
INDArrayIndex[] basicTest = {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 1),
NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(1, 2),
@ -88,7 +87,9 @@ public class IndexShapeTests extends BaseNd4jTest {
@Test
public void testNewAxis() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis(Nd4jBackend backend) {
//normal prepend
int[] prependAssertion = {1, 1, 1, 1, 2, 1, 3, 4, 5, 1};
INDArrayIndex[] prependTest = {NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.all(),

View File

@ -21,9 +21,10 @@
package org.nd4j.linalg.api.indexing.shape;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.Indices;
import org.nd4j.linalg.indexing.NDArrayIndex;
@ -33,25 +34,26 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class IndexShapeTests2d extends BaseNd4jTest {
public IndexShapeTests2d(Nd4jBackend backend) {
super(backend);
}
public class IndexShapeTests2d extends BaseNd4jTestWithBackends {
private long[] shape = {3, 2};
@Test
public void test2dCases() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void test2dCases(Nd4jBackend backend) {
assertArrayEquals(new long[] {1, 2}, Indices.shape(shape, NDArrayIndex.point(1)));
assertArrayEquals(new long[] {3, 1},
Indices.shape(shape, NDArrayIndex.all(), NDArrayIndex.point(1)));
}
@Test
public void testNewAxis2d() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNewAxis2d(Nd4jBackend backend) {
assertArrayEquals(new long[] {1, 3, 2}, Indices.shape(shape,
NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all()));
assertArrayEquals(new long[] {3, 1, 2}, Indices.shape(shape,

View File

@ -22,9 +22,10 @@ package org.nd4j.linalg.api.iterator;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -33,15 +34,14 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class NDIndexIteratorTest extends BaseNd4jTest {
public NDIndexIteratorTest(Nd4jBackend backend) {
super(backend);
}
public class NDIndexIteratorTest extends BaseNd4jTestWithBackends {
@Test
public void testIterate() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testIterate(Nd4jBackend backend) {
val shapeIter = new NdIndexIterator(2, 2);
val possibleSolutions = new long[][] {{0, 0}, {0, 1}, {1, 0}, {1, 1},};
View File
@ -28,9 +28,10 @@ import org.apache.commons.lang3.ArrayUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -45,18 +46,15 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j
@RunWith(Parameterized.class)
public class TestNdArrReadWriteTxt extends BaseNd4jTest {
public TestNdArrReadWriteTxt(Nd4jBackend backend) {
super(backend);
}
public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
@Test
public void compareAfterWrite(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
int [] ranksToCheck = new int[] {0,1,2,3,4};
for (int i=0; i<ranksToCheck.length;i++) {
for (int i = 0; i < ranksToCheck.length; i++) {
// log.info("Checking read write arrays with rank " + ranksToCheck[i]);
compareArrays(ranksToCheck[i],ordering(), testDir);
}
@ -84,7 +82,9 @@ public class TestNdArrReadWriteTxt extends BaseNd4jTest {
}
@Test
public void testNd4jReadWriteText(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNd4jReadWriteText(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
File dir = testDir.toFile();
int count = 0;

View File
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.nio.file.Path;
@ -35,17 +36,14 @@ import java.nio.file.Path;
import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays;
@Slf4j
@RunWith(Parameterized.class)
public class TestNdArrReadWriteTxtC extends BaseNd4jTest {
public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends {
public TestNdArrReadWriteTxtC(Nd4jBackend backend) {
super(backend);
}
@Test
public void compareAfterWrite(@TempDir Path testDir) throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
int[] ranksToCheck = new int[]{0, 1, 2, 3, 4};
for (int i = 0; i < ranksToCheck.length; i++) {
log.info("Checking read write arrays with rank " + ranksToCheck[i]);
View File
@ -21,9 +21,10 @@
package org.nd4j.linalg.api.ndarray;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
@ -32,16 +33,14 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
@RunWith(Parameterized.class)
public class TestSerialization extends BaseNd4jTest {
public TestSerialization(Nd4jBackend backend) {
super(backend);
}
public class TestSerialization extends BaseNd4jTestWithBackends {
@Test
public void testSerializationFullArrayNd4jWriteRead() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
@ -71,7 +70,9 @@ public class TestSerialization extends BaseNd4jTest {
}
@Test
public void testSerializationFullArrayJava() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayJava(Nd4jBackend backend) throws Exception {
int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
@ -102,7 +103,9 @@ public class TestSerialization extends BaseNd4jTest {
}
@Test
public void testSerializationOnViewsNd4jWriteRead() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
@ -138,7 +141,9 @@ public class TestSerialization extends BaseNd4jTest {
}
@Test
public void testSerializationOnViewsJava() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsJava(Nd4jBackend backend) throws Exception {
int length = 100;
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
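These cases rely on Nd4j's stream-based binary serialization. A minimal round-trip sketch using the same write/read entry points, with an in-memory stream standing in for the test's temporary files:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class BinaryRoundTripSketch {
    public static void main(String[] args) throws Exception {
        INDArray arr = Nd4j.linspace(1, 100, 100).reshape('c', 10, 10);
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (DataOutputStream dos = new DataOutputStream(bos)) {
            Nd4j.write(arr, dos);      // serialize to the stream
        }
        INDArray restored;
        try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
            restored = Nd4j.read(dis); // deserialize from the stream
        }
        System.out.println(arr.equals(restored)); // should print true
    }
}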
View File
@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.factory.Nd4j;
@ -39,23 +40,21 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j
@RunWith(Parameterized.class)
public class TestSerializationDoubleToFloat extends BaseNd4jTest {
DataType initialType;
public class TestSerializationDoubleToFloat extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
public TestSerializationDoubleToFloat(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@AfterEach
public void after() {
DataTypeUtil.setDTypeForContext(this.initialType);
}
@Test
public void testSerializationFullArrayNd4jWriteRead() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 4;
//WRITE OUT A DOUBLE ARRAY
@ -93,7 +92,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
}
@Test
public void testSerializationFullArrayJava() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayJava(Nd4jBackend backend) throws Exception {
int length = 100;
Nd4j.create(1);
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
@ -123,7 +124,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
}
@Test
public void testSerializationOnViewsNd4jWriteRead() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100;
Nd4j.create(1);
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
@ -153,7 +156,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
}
@Test
public void testSerializationOnViewsJava() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsJava(Nd4jBackend backend) throws Exception {
int length = 100;
Nd4j.create(1);
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
View File
@ -22,9 +22,10 @@ package org.nd4j.linalg.api.ndarray;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.factory.Nd4j;
@ -37,23 +38,21 @@ import java.io.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@RunWith(Parameterized.class)
public class TestSerializationFloatToDouble extends BaseNd4jTest {
DataType initialType;
public class TestSerializationFloatToDouble extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
public TestSerializationFloatToDouble(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@AfterEach
public void after() {
Nd4j.setDataType(this.initialType);
}
@Test
public void testSerializationFullArrayNd4jWriteRead() throws Exception {
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
int length = 100;
//WRITE OUT A FLOAT ARRAY
@ -85,7 +84,9 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
assertTrue(Transforms.abs(arr1.sub(arr2).div(arr1)).maxNumber().doubleValue() < 0.01);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationFullArrayJava() throws Exception {
int length = 100;
Nd4j.create(1);
@ -116,7 +117,9 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
assertTrue(Transforms.abs(arr1.sub(arr2).div(arr1)).maxNumber().doubleValue() < 0.01);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsNd4jWriteRead() throws Exception {
int length = 100;
Nd4j.create(1);
@ -146,7 +149,9 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
assertTrue(Transforms.abs(sub1.sub(arr2).div(sub1)).maxNumber().doubleValue() < 0.01);
}
@Test
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testSerializationOnViewsJava() throws Exception {
int length = 100;
Nd4j.create(1);
View File
@ -21,9 +21,10 @@
package org.nd4j.linalg.api.rng;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -33,14 +34,13 @@ import static org.junit.jupiter.api.Assertions.*;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class RngTests extends BaseNd4jTest {
public RngTests(Nd4jBackend backend) {
super(backend);
}
public class RngTests extends BaseNd4jTestWithBackends {
@Test
public void testRngConstitency() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRngConstitency(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(123);
INDArray arr = Nd4j.rand(1, 5);
Nd4j.getRandom().setSeed(123);
@ -49,7 +49,9 @@ public class RngTests extends BaseNd4jTest {
}
@Test
public void testRandomWithOrder() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomWithOrder(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
@ -105,7 +107,9 @@ public class RngTests extends BaseNd4jTest {
}
@Test
public void testRandomBinomial() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRandomBinomial(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
//silly tests. Just increasing the usage for randomBinomial to stop compiler warnings.
INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3);
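The consistency check above reduces to: the same seed must produce the same samples. A standalone sketch of that property (shape chosen arbitrarily):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RngSeedSketch {
    public static void main(String[] args) {
        Nd4j.getRandom().setSeed(123);
        INDArray first = Nd4j.rand(1, 5);
        Nd4j.getRandom().setSeed(123);
        INDArray second = Nd4j.rand(1, 5);
        // With the same seed, the two draws should be identical.
        System.out.println(first.equals(second)); // should print true
    }
}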
View File
@ -23,9 +23,10 @@ package org.nd4j.linalg.api.string;
import lombok.extern.slf4j.Slf4j;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -35,22 +36,23 @@ import org.nd4j.linalg.string.NDArrayStrings;
* @author Adam Gibson
*/
@Slf4j
@RunWith(Parameterized.class)
public class TestFormatting extends BaseNd4jTest {
public TestFormatting(Nd4jBackend backend) {
super(backend);
}
public class TestFormatting extends BaseNd4jTestWithBackends {
@Test
public void testTwoByTwo() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testTwoByTwo(Nd4jBackend backend) {
INDArray arr = Nd4j.create(2, 2, 2, 2);
System.out.println(new NDArrayStrings().format(arr));
}
@Test
public void testNd4jArrayString() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testNd4jArrayString(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new float[]{1f, 20000000f, 40.838383f, 3f}, new int[]{2, 2});
@ -71,7 +73,9 @@ public class TestFormatting extends BaseNd4jTest {
}
@Test
public void testRange() {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
public void testRange(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new double[][]{
{-1,0,1,0},
{-0.1, 0.1, -10, 10},
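The formatter under test converts an INDArray into its printable form. A minimal sketch using only the constructor and format call shown in this file (default formatting options; the 2x2 array mirrors the one in testNd4jArrayString):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.string.NDArrayStrings;

public class FormattingSketch {
    public static void main(String[] args) {
        INDArray arr = Nd4j.create(new float[]{1f, 20000000f, 40.838383f, 3f}, new int[]{2, 2});
        // Default formatting of the 2x2 array.
        System.out.println(new NDArrayStrings().format(arr));
    }
}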
Some files were not shown because too many files have changed in this diff.