diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index 475c45142..20167a3a1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -37,8 +37,10 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -47,13 +49,14 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.Arrays; +import java.util.stream.Stream; + import static org.deeplearning4j.nn.conf.ConvolutionMode.Same; import static org.deeplearning4j.nn.conf.ConvolutionMode.Truncate; import static org.junit.jupiter.api.Assertions.*; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.extension.ExtendWith; -@RunWith(Parameterized.class) @DisplayName("Cnn Gradient Check Test") class CNNGradientCheckTest extends BaseDL4JTest { @@ -71,15 +74,10 @@ class CNNGradientCheckTest extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - private CNN2DFormat format; - public CNNGradientCheckTest(CNN2DFormat format) { - this.format = format; - } - @Parameterized.Parameters(name = "{0}") - public static Object[] params() { - return CNN2DFormat.values(); + public static Stream params() { + return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of); } @Override @@ -89,9 +87,11 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Gradient CNNMLN") - void testGradientCNNMLN() { + @ParameterizedTest + @MethodSource("#params") + public void testGradientCNNMLN(CNN2DFormat format) { if (// Only test NCHW due to flat input format... - this.format != CNN2DFormat.NCHW) + format != CNN2DFormat.NCHW) return; // Parameterized test, testing combinations of: // (a) activation function @@ -146,9 +146,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Gradient CNNL 1 L 2 MLN") - void testGradientCNNL1L2MLN() { + void testGradientCNNL1L2MLN(CNN2DFormat format) { if (// Only test NCHW due to flat input format... - this.format != CNN2DFormat.NCHW) + format != CNN2DFormat.NCHW) return; // Parameterized test, testing combinations of: // (a) activation function @@ -245,7 +245,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Cnn With Space To Batch") - void testCnnWithSpaceToBatch() { + @ParameterizedTest + @MethodSource("#params") + public void testCnnWithSpaceToBatch(CNN2DFormat format) { Nd4j.getRandom().setSeed(12345); int nOut = 4; int[] minibatchSizes = { 2, 4 }; @@ -289,7 +291,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Cnn With Upsampling") - void testCnnWithUpsampling() { + @ParameterizedTest + @MethodSource("#params") + void testCnnWithUpsampling(CNN2DFormat format) { Nd4j.getRandom().setSeed(12345); int nOut = 4; int[] minibatchSizes = { 1, 3 }; @@ -323,7 +327,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Cnn With Subsampling") - void testCnnWithSubsampling() { + @ParameterizedTest + @MethodSource("#params") + void testCnnWithSubsampling(CNN2DFormat format) { Nd4j.getRandom().setSeed(12345); int nOut = 4; int[] minibatchSizes = { 1, 3 }; @@ -365,7 +371,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Cnn With Subsampling V 2") - void testCnnWithSubsamplingV2() { + @ParameterizedTest + @MethodSource("#params") + void testCnnWithSubsamplingV2(CNN2DFormat format) { Nd4j.getRandom().setSeed(12345); int nOut = 4; int[] minibatchSizes = { 1, 3 }; @@ -403,7 +411,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Cnn Locally Connected 2 D") - void testCnnLocallyConnected2D() { + @ParameterizedTest + @MethodSource("#params") + void testCnnLocallyConnected2D(CNN2DFormat format) { int nOut = 3; int width = 5; int height = 5; @@ -433,7 +443,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Cnn Multi Layer") - void testCnnMultiLayer() { + @ParameterizedTest + @MethodSource("#params") + void testCnnMultiLayer(CNN2DFormat format) { int nOut = 2; int[] minibatchSizes = { 1, 2, 5 }; int width = 5; @@ -473,7 +485,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Cnn Same Padding Mode") - void testCnnSamePaddingMode() { + @ParameterizedTest + @MethodSource("#params") + void testCnnSamePaddingMode(CNN2DFormat format) { int nOut = 2; int[] minibatchSizes = { 1, 3, 3, 2, 1, 2 }; // Same padding mode: insensitive to exact input size... @@ -507,7 +521,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Cnn Same Padding Mode Strided") - void testCnnSamePaddingModeStrided() { + @ParameterizedTest + @MethodSource("#params") + void testCnnSamePaddingModeStrided(CNN2DFormat format) { int nOut = 2; int[] minibatchSizes = { 1, 3 }; int width = 16; @@ -550,7 +566,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Cnn Zero Padding Layer") - void testCnnZeroPaddingLayer() { + @ParameterizedTest + @MethodSource("#params") + void testCnnZeroPaddingLayer(CNN2DFormat format) { Nd4j.getRandom().setSeed(12345); int nOut = 4; int width = 6; @@ -596,7 +614,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Deconvolution 2 D") - void testDeconvolution2D() { + @ParameterizedTest + @MethodSource("#params") + void testDeconvolution2D(CNN2DFormat format) { int nOut = 2; int[] minibatchSizes = new int[] { 1, 3, 3, 1, 3 }; int[] kernelSizes = new int[] { 1, 1, 1, 3, 3 }; @@ -641,7 +661,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Separable Conv 2 D") - void testSeparableConv2D() { + @ParameterizedTest + @MethodSource("#params") + void testSeparableConv2D(CNN2DFormat format) { int nOut = 2; int width = 6; int height = 6; @@ -686,7 +708,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Cnn Dilated") - void testCnnDilated() { + @ParameterizedTest + @MethodSource("#params") + void testCnnDilated(CNN2DFormat format) { int nOut = 2; int minibatchSize = 2; int width = 8; @@ -736,7 +760,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Cropping 2 D Layer") - void testCropping2DLayer() { + @ParameterizedTest + @MethodSource("#params") + void testCropping2DLayer(CNN2DFormat format) { Nd4j.getRandom().setSeed(12345); int nOut = 2; int width = 12; @@ -780,7 +806,9 @@ class CNNGradientCheckTest extends BaseDL4JTest { @Test @DisplayName("Test Depthwise Conv 2 D") - void testDepthwiseConv2D() { + @ParameterizedTest + @MethodSource("#params") + void testDepthwiseConv2D(CNN2DFormat format) { int nIn = 3; int depthMultiplier = 2; int nOut = nIn * depthMultiplier; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java index 0a280d9f0..5c4c42d62 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java @@ -39,8 +39,10 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -55,26 +57,22 @@ import java.io.File; import java.io.FileOutputStream; import java.io.InputStream; import java.nio.file.Path; +import java.util.Arrays; +import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -@RunWith(Parameterized.class) public class YoloGradientCheckTests extends BaseDL4JTest { static { Nd4j.setDataType(DataType.DOUBLE); } - private CNN2DFormat format; - public YoloGradientCheckTests(CNN2DFormat format){ - this.format = format; - } - @Parameterized.Parameters(name = "{0}") - public static Object[] params(){ - return CNN2DFormat.values(); - } + public static Stream params() { + return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of); + } @Override public long getTimeoutMilliseconds() { @@ -82,7 +80,9 @@ public class YoloGradientCheckTests extends BaseDL4JTest { } @Test - public void testYoloOutputLayer() { + @ParameterizedTest + @MethodSource("#params") + public void testYoloOutputLayer(CNN2DFormat format) { int depthIn = 2; int c = 3; int b = 3; @@ -159,13 +159,13 @@ public class YoloGradientCheckTests extends BaseDL4JTest { } } - private static INDArray yoloLabels(int mb, int c, int h, int w){ + private static INDArray yoloLabels(int mb, int c, int h, int w) { int labelDepth = 4 + c; INDArray labels = Nd4j.zeros(mb, labelDepth, h, w); //put 1 object per minibatch, at positions (0,0), (1,1) etc. //Positions for label boxes: (1,1) to (2,2), (2,2) to (4,4) etc - for( int i=0; i params(){ + return Arrays.asList(new DataType[]{DataType.FLOAT, DataType.DOUBLE}).stream().map(Arguments::of); } @Override @@ -74,7 +71,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testConv2d() { + @MethodSource("#params") + @ParameterizedTest + public void testConv2d(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { @@ -83,15 +82,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); INDArray labels = TestUtils.randomOneHot(2, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getConv2dNet(CNN2DFormat.NCHW, true, cm)) - .net2(getConv2dNet(CNN2DFormat.NCHW, false, cm)) - .net3(getConv2dNet(CNN2DFormat.NHWC, true, cm)) - .net4(getConv2dNet(CNN2DFormat.NHWC, false, cm)) + .net1(getConv2dNet(dataType,CNN2DFormat.NCHW, true, cm)) + .net2(getConv2dNet(dataType,CNN2DFormat.NCHW, false, cm)) + .net3(getConv2dNet(dataType,CNN2DFormat.NHWC, true, cm)) + .net4(getConv2dNet(dataType,CNN2DFormat.NHWC, false, cm)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -107,7 +106,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testSubsampling2d() { + @MethodSource("#params") + @ParameterizedTest + public void testSubsampling2d(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { @@ -116,15 +117,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); INDArray labels = TestUtils.randomOneHot(2, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getSubsampling2dNet(CNN2DFormat.NCHW, true, cm)) - .net2(getSubsampling2dNet(CNN2DFormat.NCHW, false, cm)) - .net3(getSubsampling2dNet(CNN2DFormat.NHWC, true, cm)) - .net4(getSubsampling2dNet(CNN2DFormat.NHWC, false, cm)) + .net1(getSubsampling2dNet(dataType,CNN2DFormat.NCHW, true, cm)) + .net2(getSubsampling2dNet(dataType,CNN2DFormat.NCHW, false, cm)) + .net3(getSubsampling2dNet(dataType,CNN2DFormat.NHWC, true, cm)) + .net4(getSubsampling2dNet(dataType,CNN2DFormat.NHWC, false, cm)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -140,7 +141,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testDepthwiseConv2d() { + @MethodSource("#params") + @ParameterizedTest + public void testDepthwiseConv2d(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { @@ -149,15 +152,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); INDArray labels = TestUtils.randomOneHot(2, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getDepthwiseConv2dNet(CNN2DFormat.NCHW, true, cm)) - .net2(getDepthwiseConv2dNet(CNN2DFormat.NCHW, false, cm)) - .net3(getDepthwiseConv2dNet(CNN2DFormat.NHWC, true, cm)) - .net4(getDepthwiseConv2dNet(CNN2DFormat.NHWC, false, cm)) + .net1(getDepthwiseConv2dNet(dataType,CNN2DFormat.NCHW, true, cm)) + .net2(getDepthwiseConv2dNet(dataType,CNN2DFormat.NCHW, false, cm)) + .net3(getDepthwiseConv2dNet(dataType,CNN2DFormat.NHWC, true, cm)) + .net4(getDepthwiseConv2dNet(dataType,CNN2DFormat.NHWC, false, cm)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -173,7 +176,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testSeparableConv2d() { + @MethodSource("#params") + @ParameterizedTest + public void testSeparableConv2d(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { @@ -182,15 +187,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); INDArray labels = TestUtils.randomOneHot(2, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getSeparableConv2dNet(CNN2DFormat.NCHW, true, cm)) - .net2(getSeparableConv2dNet(CNN2DFormat.NCHW, false, cm)) - .net3(getSeparableConv2dNet(CNN2DFormat.NHWC, true, cm)) - .net4(getSeparableConv2dNet(CNN2DFormat.NHWC, false, cm)) + .net1(getSeparableConv2dNet(dataType,CNN2DFormat.NCHW, true, cm)) + .net2(getSeparableConv2dNet(dataType,CNN2DFormat.NCHW, false, cm)) + .net3(getSeparableConv2dNet(dataType,CNN2DFormat.NHWC, true, cm)) + .net4(getSeparableConv2dNet(dataType,CNN2DFormat.NHWC, false, cm)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -206,7 +211,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testDeconv2d() { + @MethodSource("#params") + @ParameterizedTest + public void testDeconv2d(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { @@ -215,15 +222,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); INDArray labels = TestUtils.randomOneHot(2, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getDeconv2DNet2dNet(CNN2DFormat.NCHW, true, cm)) - .net2(getDeconv2DNet2dNet(CNN2DFormat.NCHW, false, cm)) - .net3(getDeconv2DNet2dNet(CNN2DFormat.NHWC, true, cm)) - .net4(getDeconv2DNet2dNet(CNN2DFormat.NHWC, false, cm)) + .net1(getDeconv2DNet2dNet(dataType,CNN2DFormat.NCHW, true, cm)) + .net2(getDeconv2DNet2dNet(dataType,CNN2DFormat.NCHW, false, cm)) + .net3(getDeconv2DNet2dNet(dataType,CNN2DFormat.NHWC, true, cm)) + .net4(getDeconv2DNet2dNet(dataType,CNN2DFormat.NHWC, false, cm)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -239,7 +246,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testLRN() { + @MethodSource("#params") + @ParameterizedTest + public void testLRN(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { @@ -248,15 +257,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); INDArray labels = TestUtils.randomOneHot(2, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getLrnLayer(CNN2DFormat.NCHW, true, cm)) - .net2(getLrnLayer(CNN2DFormat.NCHW, false, cm)) - .net3(getLrnLayer(CNN2DFormat.NHWC, true, cm)) - .net4(getLrnLayer(CNN2DFormat.NHWC, false, cm)) + .net1(getLrnLayer(dataType,CNN2DFormat.NCHW, true, cm)) + .net2(getLrnLayer(dataType,CNN2DFormat.NCHW, false, cm)) + .net3(getLrnLayer(dataType,CNN2DFormat.NHWC, true, cm)) + .net4(getLrnLayer(dataType,CNN2DFormat.NHWC, false, cm)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -272,7 +281,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testZeroPaddingLayer(){ + @MethodSource("#params") + @ParameterizedTest + public void testZeroPaddingLayer(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { Nd4j.getRandom().setSeed(12345); @@ -280,15 +291,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers" : "No helpers"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); INDArray labels = TestUtils.randomOneHot(2, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getZeroPaddingNet(CNN2DFormat.NCHW, true)) - .net2(getZeroPaddingNet(CNN2DFormat.NCHW, false)) - .net3(getZeroPaddingNet(CNN2DFormat.NHWC, true)) - .net4(getZeroPaddingNet(CNN2DFormat.NHWC, false)) + .net1(getZeroPaddingNet(dataType,CNN2DFormat.NCHW, true)) + .net2(getZeroPaddingNet(dataType,CNN2DFormat.NCHW, false)) + .net3(getZeroPaddingNet(dataType,CNN2DFormat.NHWC, true)) + .net4(getZeroPaddingNet(dataType,CNN2DFormat.NHWC, false)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -303,7 +314,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testCropping2DLayer(){ + @MethodSource("#params") + @ParameterizedTest + public void testCropping2DLayer(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { Nd4j.getRandom().setSeed(12345); @@ -311,15 +324,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers" : "No helpers"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); INDArray labels = TestUtils.randomOneHot(2, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getCropping2dNet(CNN2DFormat.NCHW, true)) - .net2(getCropping2dNet(CNN2DFormat.NCHW, false)) - .net3(getCropping2dNet(CNN2DFormat.NHWC, true)) - .net4(getCropping2dNet(CNN2DFormat.NHWC, false)) + .net1(getCropping2dNet(dataType,CNN2DFormat.NCHW, true)) + .net2(getCropping2dNet(dataType,CNN2DFormat.NCHW, false)) + .net3(getCropping2dNet(dataType,CNN2DFormat.NHWC, true)) + .net4(getCropping2dNet(dataType,CNN2DFormat.NHWC, false)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -334,7 +347,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testUpsampling2d(){ + @MethodSource("#params") + @ParameterizedTest + public void testUpsampling2d(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { Nd4j.getRandom().setSeed(12345); @@ -342,15 +357,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers" : "No helpers"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); INDArray labels = TestUtils.randomOneHot(2, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getUpsamplingNet(CNN2DFormat.NCHW, true)) - .net2(getUpsamplingNet(CNN2DFormat.NCHW, false)) - .net3(getUpsamplingNet(CNN2DFormat.NHWC, true)) - .net4(getUpsamplingNet(CNN2DFormat.NHWC, false)) + .net1(getUpsamplingNet(dataType,CNN2DFormat.NCHW, true)) + .net2(getUpsamplingNet(dataType,CNN2DFormat.NCHW, false)) + .net3(getUpsamplingNet(dataType,CNN2DFormat.NHWC, true)) + .net4(getUpsamplingNet(dataType,CNN2DFormat.NHWC, false)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -365,7 +380,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testBatchNormNet(){ + @MethodSource("#params") + @ParameterizedTest + public void testBatchNormNet(DataType dataType) { try { for(boolean useLogStd : new boolean[]{true, false}) { for (boolean helpers : new boolean[]{false, true}) { @@ -374,15 +391,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std"); System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); INDArray labels = TestUtils.randomOneHot(2, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, true)) - .net2(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, false)) - .net3(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, true)) - .net4(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, false)) + .net1(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NCHW, true)) + .net2(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NCHW, false)) + .net3(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NHWC, true)) + .net4(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NHWC, false)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -398,7 +415,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testCnnLossLayer() { + @MethodSource("#params") + @ParameterizedTest + public void testCnnLossLayer(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { Nd4j.getRandom().setSeed(12345); @@ -406,8 +425,8 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers" : "No helpers"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); - INDArray labelsNHWC = TestUtils.randomOneHot(this.dataType,2*6*6, 3); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); + INDArray labelsNHWC = TestUtils.randomOneHot(dataType,2*6*6, 3); labelsNHWC = labelsNHWC.reshape(2,6,6,3); INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup(); @@ -434,7 +453,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testSpaceToDepthNet(){ + @MethodSource("#params") + @ParameterizedTest + public void testSpaceToDepthNet(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { Nd4j.getRandom().setSeed(12345); @@ -442,15 +463,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers" : "No helpers"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); INDArray labels = TestUtils.randomOneHot(2, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getSpaceToDepthNet(CNN2DFormat.NCHW, true)) - .net2(getSpaceToDepthNet(CNN2DFormat.NCHW, false)) - .net3(getSpaceToDepthNet(CNN2DFormat.NHWC, true)) - .net4(getSpaceToDepthNet(CNN2DFormat.NHWC, false)) + .net1(getSpaceToDepthNet(dataType,CNN2DFormat.NCHW, true)) + .net2(getSpaceToDepthNet(dataType,CNN2DFormat.NCHW, false)) + .net3(getSpaceToDepthNet(dataType,CNN2DFormat.NHWC, true)) + .net4(getSpaceToDepthNet(dataType,CNN2DFormat.NHWC, false)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -465,7 +486,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testSpaceToBatchNet(){ + @MethodSource("#params") + @ParameterizedTest + public void testSpaceToBatchNet(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { Nd4j.getRandom().setSeed(12345); @@ -473,15 +496,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers" : "No helpers"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 16, 16); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 16, 16); INDArray labels = TestUtils.randomOneHot(8, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getSpaceToBatchNet(CNN2DFormat.NCHW, true)) - .net2(getSpaceToBatchNet(CNN2DFormat.NCHW, false)) - .net3(getSpaceToBatchNet(CNN2DFormat.NHWC, true)) - .net4(getSpaceToBatchNet(CNN2DFormat.NHWC, false)) + .net1(getSpaceToBatchNet(dataType,CNN2DFormat.NCHW, true)) + .net2(getSpaceToBatchNet(dataType,CNN2DFormat.NCHW, false)) + .net3(getSpaceToBatchNet(dataType,CNN2DFormat.NHWC, true)) + .net4(getSpaceToBatchNet(dataType,CNN2DFormat.NHWC, false)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -496,7 +519,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } @Test - public void testLocallyConnected() { + @MethodSource("#params") + @ParameterizedTest + public void testLocallyConnected(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { @@ -505,15 +530,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); INDArray labels = TestUtils.randomOneHot(2, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getLocallyConnectedNet(CNN2DFormat.NCHW, true, cm)) - .net2(getLocallyConnectedNet(CNN2DFormat.NCHW, false, cm)) - .net3(getLocallyConnectedNet(CNN2DFormat.NHWC, true, cm)) - .net4(getLocallyConnectedNet(CNN2DFormat.NHWC, false, cm)) + .net1(getLocallyConnectedNet(dataType,CNN2DFormat.NCHW, true, cm)) + .net2(getLocallyConnectedNet(dataType,CNN2DFormat.NCHW, false, cm)) + .net3(getLocallyConnectedNet(dataType,CNN2DFormat.NHWC, true, cm)) + .net4(getLocallyConnectedNet(dataType,CNN2DFormat.NHWC, false, cm)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -530,7 +555,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { @Test - public void testGlobalPooling() { + @MethodSource("#params") + @ParameterizedTest + public void testGlobalPooling(DataType dataType) { try { for (boolean helpers : new boolean[]{false, true}) { for (PoolingType pt : PoolingType.values()) { @@ -539,15 +566,15 @@ public class ConvDataFormatTests extends BaseDL4JTest { String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")"; System.out.println(" --- " + msg + " ---"); - INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12); INDArray labels = TestUtils.randomOneHot(2, 10); TestCase tc = TestCase.builder() .msg(msg) - .net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true)) - .net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false)) - .net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true)) - .net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false)) + .net1(getGlobalPoolingNet(dataType,CNN2DFormat.NCHW, pt, true)) + .net2(getGlobalPoolingNet(dataType,CNN2DFormat.NCHW, pt, false)) + .net3(getGlobalPoolingNet(dataType,CNN2DFormat.NHWC, pt, true)) + .net4(getGlobalPoolingNet(dataType,CNN2DFormat.NHWC, pt, false)) .inNCHW(inNCHW) .labelsNCHW(labels) .labelsNHWC(labels) @@ -562,9 +589,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } } - private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + private MultiLayerNetwork getConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { if (setOnLayerAlso) { - return getNetWithLayer(new ConvolutionLayer.Builder() + return getNetWithLayer(dataType,new ConvolutionLayer.Builder() .kernelSize(3, 3) .stride(2, 2) .activation(Activation.TANH) @@ -573,7 +600,7 @@ public class ConvDataFormatTests extends BaseDL4JTest { .helperAllowFallback(false) .build(), format, cm, null); } else { - return getNetWithLayer(new ConvolutionLayer.Builder() + return getNetWithLayer(dataType,new ConvolutionLayer.Builder() .kernelSize(3, 3) .stride(2, 2) .activation(Activation.TANH) @@ -583,16 +610,16 @@ public class ConvDataFormatTests extends BaseDL4JTest { } } - private MultiLayerNetwork getSubsampling2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + private MultiLayerNetwork getSubsampling2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { if (setOnLayerAlso) { - return getNetWithLayer(new SubsamplingLayer.Builder() + return getNetWithLayer(dataType,new SubsamplingLayer.Builder() .kernelSize(2, 2) .stride(1, 1) .dataFormat(format) .helperAllowFallback(false) .build(), format, cm, null); } else { - return getNetWithLayer(new SubsamplingLayer.Builder() + return getNetWithLayer(dataType,new SubsamplingLayer.Builder() .kernelSize(2, 2) .stride(1, 1) .helperAllowFallback(false) @@ -600,9 +627,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } } - private MultiLayerNetwork getSeparableConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + private MultiLayerNetwork getSeparableConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { if (setOnLayerAlso) { - return getNetWithLayer(new SeparableConvolution2D.Builder() + return getNetWithLayer(dataType,new SeparableConvolution2D.Builder() .kernelSize(3, 3) .stride(2, 2) .activation(Activation.TANH) @@ -611,7 +638,7 @@ public class ConvDataFormatTests extends BaseDL4JTest { .helperAllowFallback(false) .build(), format, cm, null); } else { - return getNetWithLayer(new SeparableConvolution2D.Builder() + return getNetWithLayer(dataType,new SeparableConvolution2D.Builder() .kernelSize(3, 3) .stride(2, 2) .activation(Activation.TANH) @@ -621,9 +648,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } } - private MultiLayerNetwork getDepthwiseConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + private MultiLayerNetwork getDepthwiseConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { if (setOnLayerAlso) { - return getNetWithLayer(new DepthwiseConvolution2D.Builder() + return getNetWithLayer(dataType,new DepthwiseConvolution2D.Builder() .depthMultiplier(2) .kernelSize(3, 3) .stride(2, 2) @@ -633,7 +660,7 @@ public class ConvDataFormatTests extends BaseDL4JTest { .helperAllowFallback(false) .build(), format, cm, null); } else { - return getNetWithLayer(new DepthwiseConvolution2D.Builder() + return getNetWithLayer(dataType,new DepthwiseConvolution2D.Builder() .depthMultiplier(2) .kernelSize(3, 3) .stride(2, 2) @@ -644,59 +671,59 @@ public class ConvDataFormatTests extends BaseDL4JTest { } } - private MultiLayerNetwork getLrnLayer(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + private MultiLayerNetwork getLrnLayer(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { if (setOnLayerAlso) { - return getNetWithLayer(new LocalResponseNormalization.Builder() + return getNetWithLayer(dataType,new LocalResponseNormalization.Builder() .dataFormat(format) .helperAllowFallback(false) .build(), format, cm, null); } else { - return getNetWithLayer(new LocalResponseNormalization.Builder() + return getNetWithLayer(dataType,new LocalResponseNormalization.Builder() .helperAllowFallback(false) .build(), format, cm, null); } } - private MultiLayerNetwork getZeroPaddingNet(CNN2DFormat format, boolean setOnLayerAlso) { + private MultiLayerNetwork getZeroPaddingNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) { if (setOnLayerAlso) { - return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2) + return getNetWithLayer(dataType,new ZeroPaddingLayer.Builder(2,2) .dataFormat(format).build(), format, ConvolutionMode.Same, null); } else { - return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2).build(), + return getNetWithLayer(dataType,new ZeroPaddingLayer.Builder(2,2).build(), format, ConvolutionMode.Same, null); } } - private MultiLayerNetwork getCropping2dNet(CNN2DFormat format, boolean setOnLayerAlso) { + private MultiLayerNetwork getCropping2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) { if (setOnLayerAlso) { - return getNetWithLayer(new Cropping2D.Builder(2,2) + return getNetWithLayer(dataType,new Cropping2D.Builder(2,2) .dataFormat(format).build(), format, ConvolutionMode.Same, null); } else { - return getNetWithLayer(new Cropping2D.Builder(2,2) + return getNetWithLayer(dataType,new Cropping2D.Builder(2,2) .build(), format, ConvolutionMode.Same, null); } } - private MultiLayerNetwork getUpsamplingNet(CNN2DFormat format, boolean setOnLayerAlso) { + private MultiLayerNetwork getUpsamplingNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) { if (setOnLayerAlso) { - return getNetWithLayer(new Upsampling2D.Builder(2) + return getNetWithLayer(dataType,new Upsampling2D.Builder(2) .dataFormat(format).build(), format, ConvolutionMode.Same, null); } else { - return getNetWithLayer(new Upsampling2D.Builder(2) + return getNetWithLayer(dataType,new Upsampling2D.Builder(2) .build(), format, ConvolutionMode.Same, null); } } - private MultiLayerNetwork getDeconv2DNet2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + private MultiLayerNetwork getDeconv2DNet2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { if (setOnLayerAlso) { - return getNetWithLayer(new Deconvolution2D.Builder().nOut(2) + return getNetWithLayer(dataType,new Deconvolution2D.Builder().nOut(2) .activation(Activation.TANH) .kernelSize(2,2) .dataFormat(format) .stride(2,2) .build(), format, cm, null); } else { - return getNetWithLayer(new Deconvolution2D.Builder().nOut(2) + return getNetWithLayer(dataType,new Deconvolution2D.Builder().nOut(2) .activation(Activation.TANH) .kernelSize(2,2) .dataFormat(format) @@ -705,50 +732,50 @@ public class ConvDataFormatTests extends BaseDL4JTest { } } - private MultiLayerNetwork getBatchNormNet(boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) { + private MultiLayerNetwork getBatchNormNet(DataType dataType,boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) { if (setOnLayerAlso) { - return getNetWithLayer(new BatchNormalization.Builder() + return getNetWithLayer(dataType,new BatchNormalization.Builder() .useLogStd(logStdev) .dataFormat(format) .helperAllowFallback(false) .nOut(3).build(), format, ConvolutionMode.Same, null); } else { - return getNetWithLayer(new BatchNormalization.Builder() + return getNetWithLayer(dataType,new BatchNormalization.Builder() .useLogStd(logStdev) .helperAllowFallback(false) .nOut(3).build(), format, ConvolutionMode.Same, null); } } - private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) { + private MultiLayerNetwork getSpaceToDepthNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) { if (setOnLayerAlso) { - return getNetWithLayer(new SpaceToDepthLayer.Builder() + return getNetWithLayer(dataType,new SpaceToDepthLayer.Builder() .blocks(2) .dataFormat(format) .build(), format, ConvolutionMode.Same, null); } else { - return getNetWithLayer(new SpaceToDepthLayer.Builder() + return getNetWithLayer(dataType,new SpaceToDepthLayer.Builder() .blocks(2) .build(), format, ConvolutionMode.Same, null); } } - private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) { + private MultiLayerNetwork getSpaceToBatchNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) { if (setOnLayerAlso) { - return getNetWithLayer(new SpaceToBatchLayer.Builder() + return getNetWithLayer(dataType,new SpaceToBatchLayer.Builder() .blocks(2, 2) .dataFormat(format) .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); } else { - return getNetWithLayer(new SpaceToBatchLayer.Builder() + return getNetWithLayer(dataType,new SpaceToBatchLayer.Builder() .blocks(2, 2) .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); } } - private MultiLayerNetwork getLocallyConnectedNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + private MultiLayerNetwork getLocallyConnectedNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { if (setOnLayerAlso) { - return getNetWithLayer(new LocallyConnected2D.Builder() + return getNetWithLayer(dataType,new LocallyConnected2D.Builder() .kernelSize(3, 3) .stride(2, 2) .activation(Activation.TANH) @@ -756,7 +783,7 @@ public class ConvDataFormatTests extends BaseDL4JTest { .nOut(3) .build(), format, cm, null); } else { - return getNetWithLayer(new LocallyConnected2D.Builder() + return getNetWithLayer(dataType,new LocallyConnected2D.Builder() .kernelSize(3, 3) .stride(2, 2) .activation(Activation.TANH) @@ -765,9 +792,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { } } - private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) { + private MultiLayerNetwork getNetWithLayer(DataType dataType,Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) { NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() - .dataType(this.dataType) + .dataType(dataType) .seed(12345) .convolutionMode(cm) .list() @@ -794,13 +821,13 @@ public class ConvDataFormatTests extends BaseDL4JTest { return net; } - private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) { + private MultiLayerNetwork getGlobalPoolingNet(DataType dataType,CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) { if (setOnLayerAlso) { - return getNetWithLayer(new GlobalPoolingLayer.Builder(pt) + return getNetWithLayer(dataType,new GlobalPoolingLayer.Builder(pt) .poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2}) .build(), format, ConvolutionMode.Same, null); } else { - return getNetWithLayer(new GlobalPoolingLayer.Builder(pt) + return getNetWithLayer(dataType,new GlobalPoolingLayer.Builder(pt) .build(), format, ConvolutionMode.Same, null); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java index d8a95c452..7ca17f048 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java @@ -45,8 +45,11 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.TimeSeriesUtils; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.enums.RnnDataFormat; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -61,30 +64,29 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.util.Arrays; +import java.util.stream.Stream; + import static org.deeplearning4j.nn.conf.RNNFormat.NCW; import static org.junit.jupiter.api.Assertions.assertEquals; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.extension.ExtendWith; @Slf4j -@RunWith(Parameterized.class) @DisplayName("Bidirectional Test") class BidirectionalTest extends BaseDL4JTest { - private RNNFormat rnnDataFormat; - public BidirectionalTest(RNNFormat rnnDataFormat) { - this.rnnDataFormat = rnnDataFormat; - } - @Parameterized.Parameters - public static Object[] params() { - return RNNFormat.values(); + public static Stream params() { + return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); } @Test @DisplayName("Compare Implementations") - void compareImplementations() { + @ParameterizedTest + @MethodSource("#params") + void compareImplementations(RNNFormat rnnDataFormat) { for (WorkspaceMode wsm : WorkspaceMode.values()) { log.info("*** Starting workspace mode: " + wsm); // Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params @@ -147,9 +149,11 @@ class BidirectionalTest extends BaseDL4JTest { } } - @Test @DisplayName("Compare Implementations Comp Graph") - void compareImplementationsCompGraph() { + @Test + @ParameterizedTest + @MethodSource("#params") + void compareImplementationsCompGraph(RNNFormat rnnFormat) { // for(WorkspaceMode wsm : WorkspaceMode.values()) { for (WorkspaceMode wsm : new WorkspaceMode[] { WorkspaceMode.NONE, WorkspaceMode.ENABLED }) { log.info("*** Starting workspace mode: " + wsm); @@ -187,8 +191,8 @@ class BidirectionalTest extends BaseDL4JTest { Gradient g2 = net2.gradient(); assertEquals(g1.gradient(), g2.gradient()); // Ensure updates are equal: - ComputationGraphUpdater u1 = (ComputationGraphUpdater) net1.getUpdater(); - ComputationGraphUpdater u2 = (ComputationGraphUpdater) net2.getUpdater(); + ComputationGraphUpdater u1 = net1.getUpdater(); + ComputationGraphUpdater u2 = net2.getUpdater(); assertEquals(u1.getUpdaterStateViewArray(), u2.getUpdaterStateViewArray()); u1.update(g1, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); u2.update(g2, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); @@ -205,7 +209,9 @@ class BidirectionalTest extends BaseDL4JTest { @Test @DisplayName("Test Serialization") - void testSerialization() throws Exception { + @ParameterizedTest + @MethodSource("#params") + void testSerialization(RNNFormat rnnDataFormat) throws Exception { for (WorkspaceMode wsm : WorkspaceMode.values()) { log.info("*** Starting workspace mode: " + wsm); Nd4j.getRandom().setSeed(12345); @@ -242,7 +248,9 @@ class BidirectionalTest extends BaseDL4JTest { @Test @DisplayName("Test Serialization Comp Graph") - void testSerializationCompGraph() throws Exception { + @ParameterizedTest + @MethodSource("#params") + void testSerializationCompGraph(RNNFormat rnnDataFormat) throws Exception { for (WorkspaceMode wsm : WorkspaceMode.values()) { log.info("*** Starting workspace mode: " + wsm); Nd4j.getRandom().setSeed(12345); @@ -277,7 +285,9 @@ class BidirectionalTest extends BaseDL4JTest { @Test @DisplayName("Test Simple Bidirectional") - void testSimpleBidirectional() { + @ParameterizedTest + @MethodSource("#params") + public void testSimpleBidirectional(RNNFormat rnnDataFormat) { for (WorkspaceMode wsm : WorkspaceMode.values()) { log.info("*** Starting workspace mode: " + wsm); Nd4j.getRandom().setSeed(12345); @@ -362,7 +372,9 @@ class BidirectionalTest extends BaseDL4JTest { @Test @DisplayName("Test Simple Bidirectional Comp Graph") - void testSimpleBidirectionalCompGraph() { + @ParameterizedTest + @MethodSource("#params") + void testSimpleBidirectionalCompGraph(RNNFormat rnnDataFormat) { for (WorkspaceMode wsm : WorkspaceMode.values()) { log.info("*** Starting workspace mode: " + wsm); Nd4j.getRandom().setSeed(12345); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java index 41b91b65a..5f7ef46b3 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java @@ -19,7 +19,6 @@ */ package org.deeplearning4j.nn.layers.recurrent; -import junit.framework.TestCase; import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.api.OptimizationAlgorithm; @@ -34,9 +33,12 @@ import org.deeplearning4j.nn.params.GravesBidirectionalLSTMParamInitializer; import org.deeplearning4j.nn.params.GravesLSTMParamInitializer; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.api.ndarray.INDArray; @@ -44,31 +46,29 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; -import org.nd4j.common.primitives.Pair; -import static org.junit.jupiter.api.Assertions.*; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.extension.ExtendWith; -@RunWith(Parameterized.class) +import java.util.Arrays; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.*; + @DisplayName("Graves Bidirectional LSTM Test") class GravesBidirectionalLSTMTest extends BaseDL4JTest { private double score = 0.0; - private RNNFormat rnnDataFormat; - public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat) { - this.rnnDataFormat = rnnDataFormat; + + public static Stream params(){ + return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); } - @Parameterized.Parameters - public static Object[] params() { - return RNNFormat.values(); - } @Test @DisplayName("Test Bidirectional LSTM Graves Forward Basic") - void testBidirectionalLSTMGravesForwardBasic() { + @MethodSource("#params") + @ParameterizedTest + void testBidirectionalLSTMGravesForwardBasic(RNNFormat rnnDataFormat) { // Very basic test of forward prop. of LSTM layer with a time series. // Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. int nIn = 13; @@ -110,19 +110,21 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest { @Test @DisplayName("Test Bidirectional LSTM Graves Backward Basic") - void testBidirectionalLSTMGravesBackwardBasic() { + @MethodSource("#params") + @ParameterizedTest + void testBidirectionalLSTMGravesBackwardBasic(RNNFormat rnnDataFormat) { // Very basic test of backprop for mini-batch + time series // Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape. - testGravesBackwardBasicHelper(13, 3, 17, 10, 7); + testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 7); // Edge case: miniBatchSize = 1 - testGravesBackwardBasicHelper(13, 3, 17, 1, 7); + testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 1, 7); // Edge case: timeSeriesLength = 1 - testGravesBackwardBasicHelper(13, 3, 17, 10, 1); + testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 1); // Edge case: both miniBatchSize = 1 and timeSeriesLength = 1 - testGravesBackwardBasicHelper(13, 3, 17, 1, 1); + testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 1, 1); } - private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) { + private void testGravesBackwardBasicHelper(RNNFormat rnnDataFormat,int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) { INDArray inputData = (rnnDataFormat == RNNFormat.NCW) ? Nd4j.ones(miniBatchSize, nIn, timeSeriesLength) : Nd4j.ones(miniBatchSize, timeSeriesLength, nIn); NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat).dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()).build(); long numParams = conf.getLayer().initializer().numParams(conf); @@ -204,7 +206,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest { @Test @DisplayName("Test Get Set Parmas") - void testGetSetParmas() { + @MethodSource("#params") + @ParameterizedTest + void testGetSetParmas(RNNFormat rnnDataFormat) { final int nIn = 2; final int layerSize = 3; final int miniBatchSize = 2; @@ -224,7 +228,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest { @Test @DisplayName("Test Simple Forwards And Backwards Activation") - void testSimpleForwardsAndBackwardsActivation() { + @MethodSource("#params") + @ParameterizedTest + void testSimpleForwardsAndBackwardsActivation(RNNFormat rnnDataFormat) { final int nIn = 2; final int layerSize = 3; final int miniBatchSize = 1; @@ -342,7 +348,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest { @Test @DisplayName("Test Gate Activation Fns Sanity Check") - void testGateActivationFnsSanityCheck() { + @MethodSource("#params") + @ParameterizedTest + void testGateActivationFnsSanityCheck(RNNFormat rnnDataFormat) { for (String gateAfn : new String[] { "sigmoid", "hardsigmoid" }) { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat).activation(Activation.TANH).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java index dad304dac..300386448 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java @@ -30,36 +30,35 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.TrainingListener; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.Arrays; import java.util.Collections; +import java.util.stream.Stream; + import static org.junit.jupiter.api.Assertions.assertEquals; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.extension.ExtendWith; -@RunWith(Parameterized.class) @DisplayName("Mask Zero Layer Test") class MaskZeroLayerTest extends BaseDL4JTest { - private RNNFormat rnnDataFormat; - public MaskZeroLayerTest(RNNFormat rnnDataFormat) { - this.rnnDataFormat = rnnDataFormat; + public static Stream params() { + return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); } - @Parameterized.Parameters - public static Object[] params() { - return RNNFormat.values(); - } - - @Test @DisplayName("Activate") - void activate() { + @Test + @ParameterizedTest + @MethodSource("#params") + void activate(RNNFormat rnnDataFormat) { // GIVEN two examples where some of the timesteps are zero. INDArray ex1 = Nd4j.create(new double[][] { new double[] { 0, 3, 5 }, new double[] { 0, 0, 2 } }); INDArray ex2 = Nd4j.create(new double[][] { new double[] { 0, 0, 2 }, new double[] { 0, 0, 2 } }); @@ -95,9 +94,12 @@ class MaskZeroLayerTest extends BaseDL4JTest { assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6); } - @Test + @DisplayName("Test Serialization") - void testSerialization() { + @Test + @ParameterizedTest + @MethodSource("#params") + void testSerialization(RNNFormat rnnDataFormat) { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder().setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java index 118fbf6b3..b21c0ffc2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java @@ -40,8 +40,10 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -51,30 +53,31 @@ import org.nd4j.common.primitives.Pair; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) @AllArgsConstructor public class RnnDataFormatTests extends BaseDL4JTest { - private boolean helpers; - private boolean lastTimeStep; - private boolean maskZeros; - @Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}") - public static List params(){ + public static Stream params() { List ret = new ArrayList<>(); for (boolean helpers: new boolean[]{true, false}) for (boolean lastTimeStep: new boolean[]{true, false}) for (boolean maskZero: new boolean[]{true, false}) ret.add(new Object[]{helpers, lastTimeStep, maskZero}); - return ret; + return ret.stream().map(Arguments::of); } @Test - public void testSimpleRnn() { + @MethodSource("#params") + @ParameterizedTest + public void testSimpleRnn(boolean helpers, + boolean lastTimeStep, + boolean maskZeros + ) { try { Nd4j.getRandom().setSeed(12345); @@ -107,7 +110,11 @@ public class RnnDataFormatTests extends BaseDL4JTest { } @Test - public void testLSTM() { + @ParameterizedTest + @MethodSource("#params") + public void testLSTM(boolean helpers, + boolean lastTimeStep, + boolean maskZeros) { try { Nd4j.getRandom().setSeed(12345); @@ -141,7 +148,11 @@ public class RnnDataFormatTests extends BaseDL4JTest { @Test - public void testGraveLSTM() { + @MethodSource("#params") + @ParameterizedTest + public void testGraveLSTM(boolean helpers, + boolean lastTimeStep, + boolean maskZeros) { try { Nd4j.getRandom().setSeed(12345); @@ -175,7 +186,11 @@ public class RnnDataFormatTests extends BaseDL4JTest { @Test - public void testGraveBiLSTM() { + @MethodSource("#params") + @ParameterizedTest + public void testGraveBiLSTM(boolean helpers, + boolean lastTimeStep, + boolean maskZeros) { try { Nd4j.getRandom().setSeed(12345); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java index 66a87c872..65f8c98f0 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java @@ -34,14 +34,20 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.graph.ComputationGraph; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.enums.RnnDataFormat; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.learning.config.AdaGrad; +import java.util.Arrays; +import java.util.stream.Stream; + import static org.deeplearning4j.nn.api.OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT; import static org.deeplearning4j.nn.weights.WeightInit.XAVIER_UNIFORM; import static org.junit.jupiter.api.Assertions.*; @@ -50,20 +56,16 @@ import static org.nd4j.linalg.activations.Activation.TANH; import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE; -@RunWith(Parameterized.class) public class TestLastTimeStepLayer extends BaseDL4JTest { - private RNNFormat rnnDataFormat; - public TestLastTimeStepLayer(RNNFormat rnnDataFormat){ - this.rnnDataFormat = rnnDataFormat; - } - @Parameterized.Parameters(name="{0}") - public static Object[] params(){ - return RNNFormat.values(); + public static Stream params(){ + return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); } @Test - public void testLastTimeStepVertex() { + @ParameterizedTest + @MethodSource("#params") + public void testLastTimeStepVertex(RNNFormat rnnDataFormat) { ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") .addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder() @@ -126,7 +128,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest { } @Test - public void testMaskingAndAllMasked(){ + @ParameterizedTest + @MethodSource("#params") + public void testMaskingAndAllMasked(RNNFormat rnnDataFormat) { ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder() .optimizationAlgo(STOCHASTIC_GRADIENT_DESCENT) .weightInit(XAVIER_UNIFORM) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java index dba4ae308..11920bce9 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java @@ -36,8 +36,11 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.enums.RnnDataFormat; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -49,25 +52,23 @@ import org.nd4j.common.primitives.Pair; import java.util.Arrays; import java.util.List; import java.util.Random; +import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -@RunWith(Parameterized.class) public class TestRnnLayers extends BaseDL4JTest { - private RNNFormat rnnDataFormat; - public TestRnnLayers(RNNFormat rnnDataFormat){ - this.rnnDataFormat = rnnDataFormat; - } - @Parameterized.Parameters - public static Object[] params(){ - return RNNFormat.values(); + public static Stream params(){ + return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); } + @Test - public void testTimeStepIs3Dimensional() { + @ParameterizedTest + @MethodSource("#params") + public void testTimeStepIs3Dimensional(RNNFormat rnnDataFormat) { int nIn = 12; int nOut = 3; @@ -117,7 +118,9 @@ public class TestRnnLayers extends BaseDL4JTest { } @Test - public void testDropoutRecurrentLayers(){ + @ParameterizedTest + @MethodSource("#params") + public void testDropoutRecurrentLayers(RNNFormat rnnDataFormat){ Nd4j.getRandom().setSeed(12345); String[] layerTypes = new String[]{"graves", "lstm", "simple"}; @@ -215,9 +218,11 @@ public class TestRnnLayers extends BaseDL4JTest { } @Test - public void testMismatchedInputLabelLength(){ + @ParameterizedTest + @MethodSource("#params") + public void testMismatchedInputLabelLength(RNNFormat rnnDataFormat){ - for( int i=0; i<2; i++ ){ + for( int i = 0; i < 2; i++) { NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder() diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java index a316ac858..58af7fe4b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java @@ -29,8 +29,10 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -38,25 +40,25 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.ops.transforms.Transforms; +import java.util.Arrays; +import java.util.stream.Stream; + import static org.junit.jupiter.api.Assertions.assertEquals; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.point; -@RunWith(Parameterized.class) public class TestSimpleRnn extends BaseDL4JTest { - private RNNFormat rnnDataFormat; - public TestSimpleRnn(RNNFormat rnnDataFormat){ - this.rnnDataFormat = rnnDataFormat; - } - @Parameterized.Parameters - public static Object[] params(){ - return RNNFormat.values(); + public static Stream params() { + return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); } + @Test - public void testSimpleRnn(){ + @ParameterizedTest + @MethodSource("#params") + public void testSimpleRnn(RNNFormat rnnDataFormat) { Nd4j.getRandom().setSeed(12345); int m = 3; @@ -125,7 +127,9 @@ public class TestSimpleRnn extends BaseDL4JTest { } @Test - public void testBiasInit(){ + @ParameterizedTest + @MethodSource("#params") + public void testBiasInit(RNNFormat rnnDataFormat) { Nd4j.getRandom().setSeed(12345); int nIn = 5; int layerSize = 6; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java index acae4faf3..44ce4c383 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java @@ -37,8 +37,10 @@ import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -47,22 +49,22 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.lossfunctions.LossFunctions; +import java.util.Arrays; +import java.util.stream.Stream; + import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) public class TestTimeDistributed extends BaseDL4JTest { - private RNNFormat rnnDataFormat; - public TestTimeDistributed(RNNFormat rnnDataFormat){ - this.rnnDataFormat = rnnDataFormat; - } - @Parameterized.Parameters - public static Object[] params(){ - return RNNFormat.values(); + public static Stream params(){ + return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of); } + @Test - public void testTimeDistributed(){ + @ParameterizedTest + @MethodSource("#params") + public void testTimeDistributed(RNNFormat rnnDataFormat){ for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) { MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() @@ -133,10 +135,12 @@ public class TestTimeDistributed extends BaseDL4JTest { @Test - public void testTimeDistributedDense(){ + @MethodSource("#params") + @ParameterizedTest + public void testTimeDistributedDense(RNNFormat rnnDataFormat){ - for( int rnnType=0; rnnType<3; rnnType++ ) { - for( int ffType=0; ffType<3; ffType++ ) { + for( int rnnType = 0; rnnType < 3; rnnType++ ) { + for( int ffType = 0; ffType < 3; ffType++ ) { Layer l0, l2; switch (rnnType) { diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/convolution/ConvDataFormatTests.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/convolution/ConvDataFormatTests.java index dad81fbd0..eab97fd66 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/convolution/ConvDataFormatTests.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/convolution/ConvDataFormatTests.java @@ -39,8 +39,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml index 66f521405..0d55475e6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml @@ -145,92 +145,6 @@ - - - org.jetbrains.kotlin - kotlin-maven-plugin - 1.4.30-M1 - - - -Xjsr305=strict - - - spring - jpa - - - - - org.jetbrains.kotlin - kotlin-maven-allopen - ${kotlin.version} - - - org.jetbrains.kotlin - kotlin-maven-noarg - ${kotlin.version} - - - - - compile - compile - - - ${project.basedir}/src/main/stubs - ${project.basedir}/src/main/kotlin - ${project.basedir}/src/main/java - ${project.basedir}/src/main/ops - - - - - test-compile - test-compile - - - ${project.basedir}/src/test/stubs - ${project.basedir}/src/test/kotlin - ${project.basedir}/src/test/java - ${project.basedir}/src/test/ops - - - - - - - - - org.apache.maven.plugins - maven-compiler-plugin - 3.5.1 - - - - default-compile - none - - - - default-testCompile - none - - - java-compile - compile - compile - - - java-test-compile - test-compile - testCompile - - - - ${java.version} - ${java.version} - - @@ -244,7 +158,10 @@ org.junit.jupiter junit-jupiter-engine - + + org.junit.jupiter + junit-jupiter-params + org.jetbrains.kotlin kotlin-stdlib-jdk8 @@ -261,11 +178,14 @@ org.nd4j samediff-import-tensorflow ${project.version} + compile org.nd4j samediff-import-onnx ${project.version} + compile + org.nd4j diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java index db1f6a270..faba74a48 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java @@ -22,11 +22,6 @@ package org.nd4j; import lombok.extern.slf4j.Slf4j; import org.nd4j.common.tests.AbstractAssertTestsClass; import org.nd4j.common.tests.BaseND4JTest; -import org.nd4j.imports.tfgraphs.TFGraphTestAllLibnd4j; -import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff; -import org.nd4j.imports.tfgraphs.TFGraphTestList; -import org.nd4j.imports.tfgraphs.TFGraphTestZooModels; -import org.nd4j.imports.listeners.ImportModelDebugger; import java.util.*; @Slf4j @@ -36,11 +31,6 @@ public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { protected Set> getExclusions() { //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) return new HashSet<>(Arrays.asList( - TFGraphTestAllSameDiff.class, - TFGraphTestAllLibnd4j.class, - TFGraphTestList.class, - TFGraphTestZooModels.class, - ImportModelDebugger.class //Run manually only, otherwise ignored )); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java index 2f918d639..7294833ee 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java @@ -20,19 +20,16 @@ package org.nd4j; -import org.bytedeco.javacpp.Loader; import org.junit.AfterClass; -import org.junit.BeforeClass; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Disabled; import org.junit.runner.RunWith; import org.junit.runners.Suite; import org.nd4j.autodiff.opvalidation.*; import org.nd4j.autodiff.validation.OpValidation; -import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff; +//import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.function.Function; import static org.junit.Assume.assumeFalse; @@ -49,7 +46,7 @@ import static org.junit.Assume.assumeFalse; TransformOpValidation.class, //TF import tests - TFGraphTestAllSameDiff.class + //TFGraphTestAllSameDiff.class //TFGraphTestAllLibnd4j.class }) //IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test" diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java index 9700ed253..3f2a5c689 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java @@ -27,10 +27,12 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.converters.ImportClassMapping; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.NoOp; import org.nd4j.linalg.api.ops.compat.CompatSparseToDense; @@ -122,13 +124,11 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; @Disabled("No longer relevant after model import rewrite.") -public class TestOpMapping extends BaseNd4jTest { +public class TestOpMapping extends BaseNd4jTestWithBackends { Set> subTypes; - public TestOpMapping(Nd4jBackend b){ - super(b); - + public TestOpMapping() { Reflections reflections = new Reflections("org.nd4j"); subTypes = reflections.getSubTypesOf(DifferentialFunction.class); } @@ -146,6 +146,8 @@ public class TestOpMapping extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testOpMappingCoverage() throws Exception { Map opNameMapping = ImportClassMapping.getOpNameMapping(); Map tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions(); @@ -196,7 +198,9 @@ public class TestOpMapping extends BaseNd4jTest { } @Test - public void testOpsInNamespace() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOpsInNamespace(Nd4jBackend backend) throws Exception { //Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't // want to add to a namespace for some reason) //Note that we ignore "*Bp", "*Gradient", "*Derivative" etc ops @@ -354,8 +358,11 @@ public class TestOpMapping extends BaseNd4jTest { s.add(Assign.class); } - @Test @Disabled - public void generateOpClassList() throws Exception{ + @Test + @Disabled + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void generateOpClassList(Nd4jBackend backend) throws Exception{ Reflections reflections = new Reflections("org.nd4j"); Set> subTypes = reflections.getSubTypesOf(DifferentialFunction.class); @@ -366,12 +373,7 @@ public class TestOpMapping extends BaseNd4jTest { l.add(c); } - Collections.sort(l, new Comparator>() { - @Override - public int compare(Class o1, Class o2) { - return o1.getName().compareTo(o2.getName()); - } - }); + Collections.sort(l, Comparator.comparing(Class::getName)); for(Class c : l){ System.out.println(c.getName() + ".class,"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java index c48727018..d260be072 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java @@ -22,6 +22,8 @@ package org.nd4j.autodiff; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SDVariable; @@ -31,7 +33,7 @@ import org.nd4j.autodiff.samediff.internal.FrameIter; import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.memory.NoOpMemoryMgr; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -46,19 +48,17 @@ import java.util.Map; import static org.junit.jupiter.api.Assertions.*; -public class TestSessions extends BaseNd4jTest { - - public TestSessions(Nd4jBackend b){ - super(b); - } +public class TestSessions extends BaseNd4jTestWithBackends { @Override - public char ordering(){ + public char ordering() { return 'c'; } @Test - public void testInferenceSessionBasic(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInferenceSessionBasic(Nd4jBackend backend) { //So far: trivial test to check execution order SameDiff sd = SameDiff.create(); @@ -90,7 +90,9 @@ public class TestSessions extends BaseNd4jTest { @Test - public void testInferenceSessionBasic2(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInferenceSessionBasic2(Nd4jBackend backend) { //So far: trivial test to check execution order SameDiff sd = SameDiff.create(); @@ -126,7 +128,9 @@ public class TestSessions extends BaseNd4jTest { } @Test - public void testMergeSimple(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergeSimple(Nd4jBackend backend) { //This isn't really a sensible graph, as merge op behaviour is undefined when multiple inputs are available... SameDiff sd = SameDiff.create(); @@ -162,7 +166,9 @@ public class TestSessions extends BaseNd4jTest { @Test - public void testSwitchSimple(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSwitchSimple(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable x = sd.placeHolder("x", DataType.FLOAT, 3,3); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java index 68374d35b..26e4567ff 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/internal/TestDependencyTracker.java @@ -21,10 +21,12 @@ package org.nd4j.autodiff.internal; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.samediff.internal.DependencyList; import org.nd4j.autodiff.samediff.internal.DependencyTracker; import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -35,19 +37,18 @@ import java.util.Collections; import static junit.framework.TestCase.assertNotNull; import static org.junit.jupiter.api.Assertions.*; -public class TestDependencyTracker extends BaseNd4jTest { +public class TestDependencyTracker extends BaseNd4jTestWithBackends { - public TestDependencyTracker(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { return 'c'; } - @Test - public void testSimple(){ + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSimple(Nd4jBackend backend){ DependencyTracker dt = new DependencyTracker<>(); @@ -93,8 +94,10 @@ public class TestDependencyTracker extends BaseNd4jTest { assertTrue(dt.isEmpty()); } - @Test - public void testSatisfiedBeforeAdd(){ + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSatisfiedBeforeAdd(Nd4jBackend backend){ DependencyTracker dt = new DependencyTracker<>(); //Check different order of adding dependencies: i.e., mark X as satisfied, then add x -> y dependency @@ -132,8 +135,10 @@ public class TestDependencyTracker extends BaseNd4jTest { assertFalse(dt.hasNewAllSatisfied()); } - @Test - public void testMarkUnsatisfied(){ + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMarkUnsatisfied(Nd4jBackend backend){ DependencyTracker dt = new DependencyTracker<>(); dt.addDependency("y", "x"); @@ -164,7 +169,9 @@ public class TestDependencyTracker extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIdentityDependencyTracker(){ IdentityDependencyTracker dt = new IdentityDependencyTracker<>(); assertTrue(dt.isEmpty()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ActivationGradChecks.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ActivationGradChecks.java index c3ed6099d..66467ed62 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ActivationGradChecks.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ActivationGradChecks.java @@ -21,6 +21,8 @@ package org.nd4j.autodiff.opvalidation; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.validation.GradCheckUtil; @@ -38,12 +40,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class ActivationGradChecks extends BaseOpValidation { - public ActivationGradChecks(Nd4jBackend backend) { - super(backend); - } @Test - public void testActivationGradientCheck1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testActivationGradientCheck1(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4)); @@ -61,7 +62,9 @@ public class ActivationGradChecks extends BaseOpValidation { } @Test - public void testActivationGradientCheck2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testActivationGradientCheck2(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); SDVariable x = sd.placeHolder("x", DataType.DOUBLE, 3, 4); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/BaseOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/BaseOpValidation.java index 3d22d0ded..efcd0e7d2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/BaseOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/BaseOpValidation.java @@ -21,18 +21,14 @@ package org.nd4j.autodiff.opvalidation; import org.junit.jupiter.api.BeforeEach; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.factory.Nd4jBackend; -public abstract class BaseOpValidation extends BaseNd4jTest { +public abstract class BaseOpValidation extends BaseNd4jTestWithBackends { - private DataType initialType; + private DataType initialType = Nd4j.dataType(); - public BaseOpValidation(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index e104242d5..9f78afa5b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -27,6 +27,8 @@ import java.util.List; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.validation.OpValidation; @@ -65,9 +67,6 @@ import static org.junit.jupiter.api.Assertions.*; @Slf4j public class LayerOpValidation extends BaseOpValidation { - public LayerOpValidation(Nd4jBackend backend) { - super(backend); - } @Override public long getTimeoutMilliseconds() { @@ -75,7 +74,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testXwPlusB() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testXwPlusB(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sameDiff = SameDiff.create(); @@ -109,7 +110,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testReluLayer() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReluLayer(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sameDiff = SameDiff.create(); @@ -137,7 +140,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testBiasAdd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBiasAdd(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sameDiff = SameDiff.create(); @@ -161,7 +166,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testConv2d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConv2d(Nd4jBackend backend) { //avg pool, batch norm, conv2d, max pool 2d, pooling2d, upsampling //Tested elsewhere: deconv2d, depthwise2d, LRN, sconv2d @@ -301,7 +308,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testLrn2d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLrn2d(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}}; @@ -342,7 +351,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testIm2Col() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIm2Col(Nd4jBackend backend) { //OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873 Nd4j.getRandom().setSeed(12345); @@ -381,7 +392,9 @@ public class LayerOpValidation extends BaseOpValidation { @Test - public void testOutputShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOutputShape(Nd4jBackend backend) { long[] inSize = {1, 8, 8, 3}; SameDiff sd = SameDiff.create(); @@ -431,7 +444,9 @@ public class LayerOpValidation extends BaseOpValidation { @Test - public void testAvgPool() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAvgPool(Nd4jBackend backend) { long[] inSize = {1, 8, 8, 3}; //NHWC Pooling2DConfig conf = Pooling2DConfig.builder() @@ -474,7 +489,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testConv3d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConv3d(Nd4jBackend backend) { //Pooling3d, Conv3D, batch norm Nd4j.getRandom().setSeed(12345); @@ -576,7 +593,9 @@ public class LayerOpValidation extends BaseOpValidation { @Test - public void testDepthWiseConv2dBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDepthWiseConv2dBasic(Nd4jBackend backend) { int nIn = 3; int depthWise = 4; int kH = 2; @@ -615,7 +634,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testSeparableConv2dBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSeparableConv2dBasic(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int nIn = 2; int nOut = 3; @@ -671,7 +692,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testDeconv2dBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDeconv2dBasic(Nd4jBackend backend) { int nIn = 2; int nOut = 3; int kH = 2; @@ -715,7 +738,9 @@ public class LayerOpValidation extends BaseOpValidation { @Test - public void testConv2dBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConv2dBasic(Nd4jBackend backend) { int nIn = 3; int nOut = 4; int kH = 2; @@ -756,7 +781,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testMaxPoolingArgMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMaxPoolingArgMax(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int nIn = 3; int kH = 2; @@ -785,7 +812,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testMaxPooling2dBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMaxPooling2dBasic(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int nIn = 3; int kH = 2; @@ -843,7 +872,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testAvgPooling2dBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAvgPooling2dBasic(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int nIn = 3; int kH = 2; @@ -892,7 +923,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testAvgPooling3dBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAvgPooling3dBasic(Nd4jBackend backend) { int nIn = 3; int kH = 2; int kW = 2; @@ -929,7 +962,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testMaxPooling3dBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMaxPooling3dBasic(Nd4jBackend backend) { int nIn = 3; int kH = 2; int kW = 2; @@ -967,7 +1002,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testConv1dBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConv1dBasic(Nd4jBackend backend) { int nIn = 3; int nOut = 4; int k = 2; @@ -1002,7 +1039,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testConv1dCausal() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConv1dCausal(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int nIn = 3; int nOut = 4; @@ -1051,7 +1090,9 @@ public class LayerOpValidation extends BaseOpValidation { @Test - public void testConv1dForward() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConv1dForward(Nd4jBackend backend) { int nIn = 2; int nOut = 1; int kernel = 3; @@ -1094,7 +1135,9 @@ public class LayerOpValidation extends BaseOpValidation { @Test - public void testConv3dBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConv3dBasic(Nd4jBackend backend) { int nIn = 3; int nOut = 4; int kH = 2; @@ -1140,7 +1183,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testDeConv3dBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDeConv3dBasic(Nd4jBackend backend) { int nIn = 4; int nOut = 3; int kH = 2; @@ -1185,7 +1230,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testLayerNorm() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLayerNorm(Nd4jBackend backend) { final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray standardized = random.ulike(); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); @@ -1210,7 +1257,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testLayerNorm4d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLayerNorm4d(Nd4jBackend backend) { int mb = 3; int ch = 4; for (boolean nchw : new boolean[]{true, false}) { @@ -1242,7 +1291,9 @@ public class LayerOpValidation extends BaseOpValidation { @Test - public void testLayerNormOP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLayerNormOP(Nd4jBackend backend) { final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray standardized = random.ulike(); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); @@ -1258,7 +1309,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testLayerNormNoBias() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLayerNormNoBias(Nd4jBackend backend) { final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray standardized = random.ulike(); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); @@ -1281,7 +1334,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testLayerNormOPNoBias() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLayerNormOPNoBias(Nd4jBackend backend) { final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray standardized = random.ulike(); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); @@ -1296,7 +1351,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testLayerNormNoDeviation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLayerNormNoDeviation(Nd4jBackend backend) { final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); for (int i = 0; i < 4; i++) { random.putScalar(1, i, 7); @@ -1326,36 +1383,36 @@ public class LayerOpValidation extends BaseOpValidation { } @Test() - public void exceptionThrown_WhenConv1DConfigInvalid() { - assertThrows(IllegalArgumentException.class,() -> { - int nIn = 3; - int nOut = 4; - int k = 2; - int mb = 3; - int img = 28; + public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) { + assertThrows(IllegalArgumentException.class,() -> { + int nIn = 3; + int nOut = 4; + int k = 2; + int mb = 3; + int img = 28; - SameDiff sd = SameDiff.create(); - INDArray wArr = Nd4j.create(k, nIn, nOut); - INDArray inArr = Nd4j.create(mb, nIn, img); + SameDiff sd = SameDiff.create(); + INDArray wArr = Nd4j.create(k, nIn, nOut); + INDArray inArr = Nd4j.create(mb, nIn, img); - SDVariable in = sd.var("in", inArr); - SDVariable w = sd.var("W", wArr); + SDVariable in = sd.var("in", inArr); + SDVariable w = sd.var("W", wArr); - SDVariable[] vars = new SDVariable[]{in, w}; + SDVariable[] vars = new SDVariable[]{in, w}; - Conv1DConfig conv1DConfig = Conv1DConfig.builder() - .k(k).p(-1).s(0) - .paddingMode(PaddingMode.VALID) - .build(); + Conv1DConfig conv1DConfig = Conv1DConfig.builder() + .k(k).p(-1).s(0) + .paddingMode(PaddingMode.VALID) + .build(); - SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); + SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); - }); + }); } @Test() - public void exceptionThrown_WhenConv2DConfigInvalid() { + public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { Nd4j.getRandom().setSeed(12345); @@ -1378,40 +1435,42 @@ public class LayerOpValidation extends BaseOpValidation { } @Test() - public void exceptionThrown_WhenConf3DInvalid() { - assertThrows(IllegalArgumentException.class,() -> { - Nd4j.getRandom().setSeed(12345); + public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) { + assertThrows(IllegalArgumentException.class,() -> { + Nd4j.getRandom().setSeed(12345); - //NCDHW format - int[] inSizeNCDHW = {2, 3, 4, 5, 5}; + //NCDHW format + int[] inSizeNCDHW = {2, 3, 4, 5, 5}; - List failed = new ArrayList<>(); + List failed = new ArrayList<>(); - for (boolean ncdhw : new boolean[]{true, false}) { - int nIn = inSizeNCDHW[1]; - int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW)); + for (boolean ncdhw : new boolean[]{true, false}) { + int nIn = inSizeNCDHW[1]; + int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW)); - SameDiff sd = SameDiff.create(); - SDVariable in = sd.var("in", shape); + SameDiff sd = SameDiff.create(); + SDVariable in = sd.var("in", shape); - SDVariable out; - String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape); + SDVariable out; + String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape); - SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC] - SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10)); - out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder() - .dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC) - .isSameMode(true) - .kH(2).kW(2).kD(2) - .sD(1).sH(1).sW(-1).dW(-1) - .build()); - } - }); + SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC] + SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10)); + out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder() + .dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC) + .isSameMode(true) + .kH(2).kW(2).kD(2) + .sD(1).sH(1).sW(-1).dW(-1) + .build()); + } + }); } @Test - public void testLayerNormMixedOrders() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLayerNormMixedOrders(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f'); INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f'); @@ -1458,7 +1517,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void testBiasAdd_nchw_nhwc() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBiasAdd_nchw_nhwc(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); for (boolean nchw : new boolean[]{true, false}) { @@ -1489,6 +1550,8 @@ public class LayerOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDepthwiseConv2D(){ int bS = 10; @@ -1527,7 +1590,9 @@ public class LayerOpValidation extends BaseOpValidation { @Test - public void LSTMLayerTestCase1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void LSTMLayerTestCase1(Nd4jBackend backend) { int bS = 5; int nIn = 3; @@ -1602,7 +1667,9 @@ public class LayerOpValidation extends BaseOpValidation { @Test - public void LSTMLayerTestCase2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void LSTMLayerTestCase2(Nd4jBackend backend) { int bS = 5; int nIn = 3; int numUnits = 7; @@ -1660,7 +1727,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void LSTMLayerTestCase3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void LSTMLayerTestCase3(Nd4jBackend backend) { int bS = 5; int nIn = 3; int numUnits = 7; @@ -1721,7 +1790,9 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - public void GRUTestCase() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void GRUTestCase(Nd4jBackend backend) { int bS = 5; int nIn = 4; int nOut = 6; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java index 7f5cf0884..dcf7d6971 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java @@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; @@ -43,9 +45,7 @@ import static org.junit.jupiter.api.Assertions.*; @Slf4j public class LossOpValidation extends BaseOpValidation { - public LossOpValidation(Nd4jBackend backend) { - super(backend); - } + @Override public long getTimeoutMilliseconds() { @@ -56,7 +56,9 @@ public class LossOpValidation extends BaseOpValidation { public static final Set NO_BP_YET = new HashSet<>(); @Test - public void testLoss2d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLoss2d(Nd4jBackend backend) { final List oneDimensionalOutputFns = Arrays.asList("cosine", "mpwse", "softmaxxent", "softmaxxent_smooth", "mpwse", "sparsesoftmax"); Nd4j.getRandom().setSeed(12345); @@ -69,7 +71,7 @@ public class LossOpValidation extends BaseOpValidation { "absdiff", "cosine", "hinge", "huber", "log", "mse", "sigmoidxent", "sigmoidxent_smooth", "softmaxxent", "softmaxxent_smooth", "mpwse", "sparsesoftmax" - }) { + }) { for(String weights : new String[]{"none", "scalar", "perExample", "perOutput"}) { @@ -368,6 +370,8 @@ public class LossOpValidation extends BaseOpValidation { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCosineDistance(){ INDArray arr = Nd4j.create(new double[][]{{-0.3, -0.2, -0.1}, {0, 0.1, 0.2}}); INDArray label = Nd4j.create(new double[][]{{1.0, 2.0, 3.0}, {-1.0, 2.0, 1.0}}); @@ -386,6 +390,8 @@ public class LossOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testL2Loss(){ for( int rank=0; rank<=3; rank++ ){ @@ -428,7 +434,9 @@ public class LossOpValidation extends BaseOpValidation { } @Test - public void testNonZeroResult() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNonZeroResult(Nd4jBackend backend) { INDArray predictions = Nd4j.rand(DataType.DOUBLE, 10, 5); INDArray w = Nd4j.scalar(1.0); INDArray label = Nd4j.rand(DataType.DOUBLE, 10, 5); @@ -486,6 +494,8 @@ public class LossOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void TestStdLossMixedDataType(){ // Default Data Type in this test suite is Double. // This test used to throw an Exception that we have mixed data types. diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index 58c8f0825..0ca30d2ae 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -23,6 +23,8 @@ package org.nd4j.autodiff.opvalidation; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -78,13 +80,12 @@ import static org.junit.Assume.assumeNotNull; @Slf4j public class MiscOpValidation extends BaseOpValidation { - public MiscOpValidation(Nd4jBackend backend) { - super(backend); - } @Test - public void testGradientAutoBroadcast1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGradientAutoBroadcast1(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); @@ -171,7 +172,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testGradientAutoBroadcast2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGradientAutoBroadcast2(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); List failed = new ArrayList<>(); @@ -260,7 +263,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testGradientAutoBroadcast3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGradientAutoBroadcast3(Nd4jBackend backend) { //These tests: output size > input sizes Nd4j.getRandom().setSeed(12345); @@ -368,7 +373,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testScatterOpGradients() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterOpGradients(Nd4jBackend backend) { List failed = new ArrayList<>(); for (int i = 0; i < 7; i++) { @@ -470,6 +477,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testScatterUpdate(){ INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3); INDArray updates = Nd4j.create(new float[][]{ @@ -491,7 +500,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testGatherGradient() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGatherGradient(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); List failed = new ArrayList<>(); @@ -542,6 +553,8 @@ public class MiscOpValidation extends BaseOpValidation { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTrace(){ //TODO need to work out how to handle shape_op for scalars... //OpValidationSuite.ignoreFailing(); @@ -567,7 +580,9 @@ public class MiscOpValidation extends BaseOpValidation { @Test - public void testTensorGradTensorMmul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTensorGradTensorMmul(Nd4jBackend backend) { OpValidationSuite.ignoreFailing(); Nd4j.getRandom().setSeed(12345); @@ -589,7 +604,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testMulGradient() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMulGradient(Nd4jBackend backend) { INDArray arr1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray arr2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); @@ -654,22 +671,21 @@ public class MiscOpValidation extends BaseOpValidation { @Test - public void testMmulGradientManual() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmulGradientManual(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); Map inputs = new HashMap<>(); inputs.put("x", sumInput); inputs.put("y", sumInput.dup()); - sameDiff.defineFunction("mmulGradient", new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable input = sameDiff.var("x", inputs.get("x")); - SDVariable input2 = sameDiff.var("y", inputs.get("y")); - SDVariable exp = sameDiff.mmul(input, input2); - SDVariable sum = sameDiff.sum(exp, Integer.MAX_VALUE); - return new SDVariable[]{sum}; - } + sameDiff.defineFunction("mmulGradient", (sameDiff1, inputs1, variableInputs) -> { + SDVariable input = sameDiff1.var("x", inputs1.get("x")); + SDVariable input2 = sameDiff1.var("y", inputs1.get("y")); + SDVariable exp = sameDiff1.mmul(input, input2); + SDVariable sum = sameDiff1.sum(exp, Integer.MAX_VALUE); + return new SDVariable[]{sum}; }, inputs); @@ -698,6 +714,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMmulGradients(){ int[] aShape = new int[]{2,3}; int[] bShape = new int[]{3,4}; @@ -749,7 +767,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testBatchMmulBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBatchMmulBasic(Nd4jBackend backend) { OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6873 int M = 5; int N = 3; @@ -774,7 +794,9 @@ public class MiscOpValidation extends BaseOpValidation { @Test - public void testMmulWithTranspose() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmulWithTranspose(Nd4jBackend backend) { //Here: [x,3]^T * [x,4] = [3,4] @@ -811,6 +833,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMmulOutputSizeCalculation(){ //[3,2] x [2,4] with result transpose: output shape [4,3] INDArray a = Nd4j.create(3,2); @@ -820,7 +844,7 @@ public class MiscOpValidation extends BaseOpValidation { .transposeA(false) .transposeB(false) .transposeResult(true) - .build()); + .build()); val outShapes = Nd4j.getExecutioner().calculateOutputShape(m); assertArrayEquals(new long[]{4,3}, outShapes.get(0).getShape()); @@ -843,6 +867,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testFillOp(){ INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT); @@ -857,6 +883,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testClipByNorm(){ //Expected: if array.norm2(1) is less than 1.0, not modified //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() @@ -889,6 +917,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testClipByNorm2(){ //Expected: if array.norm2(1) is less than 1.0, not modified //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() @@ -932,6 +962,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testClipByNorm1(){ //Expected: if array.norm2(1) is less than 1.0, not modified //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() @@ -972,6 +1004,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testClipByNorm0(){ //Expected: if array.norm2(0) is less than 1.0, not modified //Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2() @@ -1001,6 +1035,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCumSum(){ List failing = new ArrayList<>(); @@ -1066,6 +1102,8 @@ public class MiscOpValidation extends BaseOpValidation { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCumProd(){ List failing = new ArrayList<>(); @@ -1134,6 +1172,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testOneHot1(){ List failed = new ArrayList<>(); @@ -1164,6 +1204,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testOneHotOp(){ //https://www.tensorflow.org/api_docs/python/tf/one_hot //https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp @@ -1178,7 +1220,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testOneHot2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOneHot2(Nd4jBackend backend) { INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1); @@ -1198,7 +1242,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testOneHot4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOneHot4(Nd4jBackend backend) { INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1); @@ -1218,7 +1264,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testOneHot3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOneHot3(Nd4jBackend backend) { //https://github.com/deeplearning4j/deeplearning4j/issues/6872 //https://www.tensorflow.org/api_docs/python/tf/one_hot @@ -1253,6 +1301,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLinspace(){ SameDiff sd = SameDiff.create(); SDVariable out = sd.linspace("linspace", DataType.DOUBLE, 1,10,10); @@ -1266,6 +1316,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLinspace2(){ OpValidationSuite.ignoreFailing(); //TODO 2019/01/18 SameDiff sd = SameDiff.create(); @@ -1280,7 +1332,9 @@ public class MiscOpValidation extends BaseOpValidation { @Test - public void testShapeFn() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShapeFn(Nd4jBackend backend) { INDArray in = Nd4j.create(new long[]{1, 2}); @@ -1294,7 +1348,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testShapeFn2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShapeFn2(Nd4jBackend backend) { INDArray i = Nd4j.create(1,3); @@ -1307,6 +1363,8 @@ public class MiscOpValidation extends BaseOpValidation { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMergeRank1(){ SameDiff sd = SameDiff.create(); SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5)); @@ -1325,7 +1383,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testDiagPart() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDiagPart(Nd4jBackend backend) { INDArray i = Nd4j.create(5,5); SameDiff sd = SameDiff.create(); @@ -1337,7 +1397,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testDiagShapeFn() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDiagShapeFn(Nd4jBackend backend) { INDArray i = Nd4j.create(5,5); CustomOp op = new DiagPart(i, null); @@ -1350,6 +1412,8 @@ public class MiscOpValidation extends BaseOpValidation { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testZerosOnesLike(){ Nd4j.getRandom().setSeed(12345); @@ -1392,6 +1456,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testZerosLikeOp(){ INDArray arr = Nd4j.scalar(DataType.DOUBLE, 1.0); @@ -1407,6 +1473,8 @@ public class MiscOpValidation extends BaseOpValidation { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testConfusionMatrix(){ DataType dt = DataType.DOUBLE; @@ -1443,6 +1511,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIsNonDecreasingIsStrictlyIncr(){ List shapes = Arrays.asList(null, new long[]{12}, new long[]{1,12}, new long[]{3,4}, new long[]{2,2,3}); @@ -1506,6 +1576,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testExtractImagePatches(){ /* tf.reset_default_graph() @@ -1553,6 +1625,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSegmentProdBpSimple(){ INDArray segmentIdxs = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT); @@ -1573,6 +1647,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMmulRank4() throws Exception { Nd4j.getRandom().setSeed(12345); @@ -1608,6 +1684,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMmulRank4_simple(){ INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64); @@ -1634,6 +1712,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testNthElementRank1(){ INDArray in = Nd4j.createFromArray(new double[]{0,1,2,3,4,5,6,7,8,9}); INDArray n = Nd4j.scalar(0); @@ -1656,6 +1736,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTensorMmulShape(){ INDArray a = Nd4j.create(new double[]{2}).reshape(1); INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2); @@ -1674,6 +1756,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTensorMmulShape2(){ INDArray a = Nd4j.create(new double[]{2}).reshape(1); INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2); @@ -1682,6 +1766,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testStopGradient(){ SameDiff sd = SameDiff.create(); @@ -1701,6 +1787,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCheckNumerics(){ OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/7927 @@ -1744,7 +1832,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testCheckNumerics2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCheckNumerics2(Nd4jBackend backend) { INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4); INDArray msg = Nd4j.scalar("My error message!"); @@ -1757,6 +1847,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testHistogramFixedWidth(){ //Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf] INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9); @@ -1775,6 +1867,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDynamicPartition(){ INDArray data = Nd4j.createFromArray(2, 1, 2, 0); INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); @@ -1793,6 +1887,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testListDiff(){ INDArray x = Nd4j.createFromArray(0, 1, 2, 3); INDArray y = Nd4j.createFromArray(3, 1); @@ -1812,7 +1908,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testDivideNoNan() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDivideNoNan(Nd4jBackend backend) { OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff() SameDiff sameDiff = SameDiff.create(); @@ -1836,7 +1934,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testDigamma() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDigamma(Nd4jBackend backend) { INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); @@ -1851,7 +1951,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testFlatten() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFlatten(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); @@ -1873,7 +1975,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testFusedBatchNorm() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFusedBatchNorm(Nd4jBackend backend) { OpValidationSuite.ignoreFailing(); SameDiff sameDiff = SameDiff.create(); @@ -1918,7 +2022,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testIgamma() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIgamma(Nd4jBackend backend) { INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4); @@ -1934,7 +2040,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testIgammaC() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIgammaC(Nd4jBackend backend) { INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4); @@ -1951,7 +2059,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testLgamma() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLgamma(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); @@ -1976,7 +2086,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testLu() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLu(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); @@ -2007,7 +2119,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testMatrixBandPart() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatrixBandPart(Nd4jBackend backend) { OpValidationSuite.ignoreFailing(); SameDiff sameDiff = SameDiff.create(); @@ -2037,7 +2151,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testPolygamma() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPolygamma(Nd4jBackend backend) { INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4); @@ -2053,7 +2169,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testTriangularSolve() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTriangularSolve(Nd4jBackend backend) { INDArray a = Nd4j.createFromArray(new float[]{ 3.f, 0.f, 0.f, 0.f, @@ -2077,7 +2195,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testBiasAdd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBiasAdd(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); @@ -2106,7 +2226,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testBiasAddGrad() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBiasAddGrad(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); @@ -2126,7 +2248,9 @@ public class MiscOpValidation extends BaseOpValidation { } @Test - public void testRoll() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRoll(Nd4jBackend backend) { INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}). @@ -2146,6 +2270,8 @@ public class MiscOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSeqMask(){ INDArray arr = Nd4j.createFromArray(1,2,3); INDArray maxLen = Nd4j.scalar(4); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java index 8c6b62dcd..0715f94fc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java @@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -51,12 +53,11 @@ import static org.junit.jupiter.api.Assertions.*; @Slf4j public class RandomOpValidation extends BaseOpValidation { - public RandomOpValidation(Nd4jBackend backend) { - super(backend); - } @Test - public void testRandomOpsSDVarShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRandomOpsSDVarShape(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); List failed = new ArrayList<>(); @@ -157,7 +158,9 @@ public class RandomOpValidation extends BaseOpValidation { } @Test - public void testRandomOpsLongShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRandomOpsLongShape(Nd4jBackend backend) { List failed = new ArrayList<>(); for (long[] shape : Arrays.asList(new long[]{1000}, new long[]{100, 10}, new long[]{40, 5, 5})) { @@ -283,6 +286,8 @@ public class RandomOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRandomBinomial(){ INDArray z = Nd4j.create(new long[]{10}); @@ -293,7 +298,9 @@ public class RandomOpValidation extends BaseOpValidation { } @Test - public void testUniformRankSimple() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUniformRankSimple(Nd4jBackend backend) { INDArray arr = Nd4j.createFromArray(new double[]{100.0}); // OpTestCase tc = new OpTestCase(DynamicCustomOp.builder("randomuniform") @@ -325,7 +332,9 @@ public class RandomOpValidation extends BaseOpValidation { @Test - public void testRandomExponential() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRandomExponential(Nd4jBackend backend) { long length = 1_000_000; INDArray shape = Nd4j.createFromArray(new double[]{length}); INDArray out = Nd4j.createUninitialized(new long[]{length}); @@ -347,6 +356,8 @@ public class RandomOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRange(){ //Technically deterministic, not random... @@ -380,6 +391,8 @@ public class RandomOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testAllEmptyReduce(){ INDArray x = Nd4j.createFromArray(true, true, true); All all = new All(x); @@ -389,6 +402,8 @@ public class RandomOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testUniformDtype(){ Nd4j.getRandom().setSeed(12345); for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){ @@ -417,6 +432,8 @@ public class RandomOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRandomExponential2(){ Nd4j.getRandom().setSeed(12345); DynamicCustomOp op = DynamicCustomOp.builder("random_exponential") diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java index 72a0dcadf..34e8f7c37 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java @@ -24,6 +24,8 @@ import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.validation.OpTestCase; import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.linalg.api.buffer.DataType; @@ -51,10 +53,6 @@ public class ReductionBpOpValidation extends BaseOpValidation { private DataType initialType; - public ReductionBpOpValidation(Nd4jBackend backend) { - super(backend); - } - @BeforeEach public void before() { Nd4j.create(1); @@ -71,14 +69,16 @@ public class ReductionBpOpValidation extends BaseOpValidation { @AfterEach - public void tearDown() { + public void tearDown(Nd4jBackend backend) { NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); } @Test - public void testReduceSumBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduceSumBP(Nd4jBackend backend) { //Full array reduction //reduce_sum_bp op: has 2 inputs (original pre-reduce input, and gradient at output (epsilon)) @@ -104,7 +104,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testReduceSumAlongDim0BP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduceSumAlongDim0BP(Nd4jBackend backend) { //Reduction along dimension //Inputs/outputs as before - but note that the output is no longer a scalar @@ -130,7 +132,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testReduceSumAlongDim1BP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduceSumAlongDim1BP(Nd4jBackend backend) { //Reduction along dimension //Inputs/outputs as before - but note that the output is no longer a scalar @@ -158,7 +162,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { @Test - public void testMeanBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMeanBP(Nd4jBackend backend) { //dL/dIn_i = dL/dOut * dOut/dIn_i = dL/dOut * (1/N * sum_j (in_j)) // = 1/N * dL/dOut @@ -189,7 +195,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testMeanBP_Rank1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMeanBP_Rank1(Nd4jBackend backend) { INDArray dLdOut = Nd4j.scalar(0.5); INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3}); INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5 / 3); @@ -202,7 +210,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testMeanAlongDim0BP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMeanAlongDim0BP(Nd4jBackend backend) { //Reduction along dimension //Inputs/outputs as before - but note that the output is no longer a scalar @@ -230,7 +240,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testMeanAlongDim1BP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMeanAlongDim1BP(Nd4jBackend backend) { //Reduction along dimension //Inputs/outputs as before - but note that the output is no longer a scalar @@ -258,7 +270,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { @Test - public void testMinBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMinBP(Nd4jBackend backend) { //Full array min reduction //dL/dIn_i = dL/dOut * dOut/dIn_i @@ -297,7 +311,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testMinAlongDimensionBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMinAlongDimensionBP(Nd4jBackend backend) { //Full array min reduction //dL/dIn_i = dL/dOut * dOut/dIn_i @@ -340,7 +356,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testMaxBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMaxBP(Nd4jBackend backend) { //Full array max reduction //dL/dIn_i = dL/dOut * dOut/dIn_i @@ -370,7 +388,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testMaxAlongDimensionBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMaxAlongDimensionBP(Nd4jBackend backend) { //Full array min reduction //dL/dIn_i = dL/dOut * dOut/dIn_i @@ -413,7 +433,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testProdBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testProdBP(Nd4jBackend backend) { //Full array product reduction //dL/dIn_i = dL/dOut * dOut/dIn_i @@ -442,7 +464,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testProdAlongDimensionBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testProdAlongDimensionBP(Nd4jBackend backend) { //dL/dIn_i = dL/dOut * dOut/dIn_i // = dL/dOut * d(prod(in))/dIn_i // = dL/dOut * (prod(in) / in_i) @@ -498,7 +522,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testStdevBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStdevBP(Nd4jBackend backend) { //If out = stdev(in) then: //dL/dIn = dL/dOut * dOut/dIn //dOut/dIn_i = (in_i-mean)/(stdev * (n-1)) @@ -534,7 +560,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testStdevBP_Rank1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStdevBP_Rank1(Nd4jBackend backend) { INDArray dLdOut = Nd4j.scalar(0.5); INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3}); double stdev = preReduceInput.stdNumber(true).doubleValue(); @@ -555,7 +583,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testStdevAlongDimensionBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStdevAlongDimensionBP(Nd4jBackend backend) { //If out = stdev(in) then: //dL/dIn = dL/dOut * dOut/dIn //dOut/dIn_i = (in_i-mean)/(stdev * (n-1)) @@ -600,7 +630,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testVarianceBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVarianceBP(Nd4jBackend backend) { //If out = variance(in) then: //dL/dIn = dL/dOut * dOut/dIn //dOut/dIn_i = 2*(in_i-mean)/(n-1) @@ -636,7 +668,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testVarianceAlongDimensionBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVarianceAlongDimensionBP(Nd4jBackend backend) { //If out = variance(in) then: //dL/dIn = dL/dOut * dOut/dIn //dOut/dIn_i = 2*(in_i-mean)/(n-1) @@ -678,7 +712,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { @Test - public void testCumSumBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCumSumBP(Nd4jBackend backend) { //Standard case, non-reverse, non-exclusive //dL/dIn_i = sum_j dL/dOut_j * dOut_j/dIn_i // = sum_j dL/dOut_j * d(in_0 + ... + in_j)/dIn_i @@ -748,7 +784,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { @Test - public void testNorm2Bp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm2Bp(Nd4jBackend backend) { //dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * x/|x|_2 @@ -775,7 +813,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testNorm2AlongDimensionBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm2AlongDimensionBP(Nd4jBackend backend) { //dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * x/|x|_2 @@ -808,7 +848,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testNorm1Bp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm1Bp(Nd4jBackend backend) { //dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * sgn(in) @@ -835,7 +877,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testNorm1AlongDimensionBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm1AlongDimensionBP(Nd4jBackend backend) { //dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * sgn(in) @@ -867,7 +911,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testNormMaxBp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNormMaxBp(Nd4jBackend backend) { //out = max_i (|in_i|) //dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise) @@ -897,7 +943,9 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testNormMaxAlongDimensionBP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNormMaxAlongDimensionBP(Nd4jBackend backend) { //out = max_i (|in_i|) //dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index 2f1cea2d7..6f7880a01 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -76,16 +77,13 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.*; @Slf4j -@RunWith(Parameterized.class) + public class ReductionOpValidation extends BaseOpValidation { - - public ReductionOpValidation(Nd4jBackend backend) { - super(backend); - } - @Test - public void testStdev() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStdev(Nd4jBackend backend) { List errors = new ArrayList<>(); for (Pair p : NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345, DataType.DOUBLE)) { @@ -111,7 +109,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testZeroCount() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testZeroCount(Nd4jBackend backend) { List allFailed = new ArrayList<>(); for (int i = 0; i < 21; i++) { SameDiff sd = SameDiff.create(); @@ -145,7 +145,9 @@ public class ReductionOpValidation extends BaseOpValidation { @Test - public void testZeroFraction() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testZeroFraction(Nd4jBackend backend) { List allFailed = new ArrayList<>(); for (int i = 0; i < 2; i++) { SameDiff sd = SameDiff.create(); @@ -175,7 +177,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testReductionGradientsSimple() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReductionGradientsSimple(Nd4jBackend backend) { //OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES //Test reductions: final and only function Nd4j.getRandom().setSeed(12345); @@ -344,7 +348,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testReductionGradients1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReductionGradients1(Nd4jBackend backend) { //Test reductions: final, but *not* the only function Nd4j.getRandom().setSeed(12345); @@ -472,7 +478,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testReductionGradients2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReductionGradients2(Nd4jBackend backend) { //Test reductions: NON-final function Nd4j.getRandom().setSeed(12345); @@ -650,7 +658,9 @@ public class ReductionOpValidation extends BaseOpValidation { @Test - public void testReduce3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduce3(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int d0 = 3; @@ -755,7 +765,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testMoments() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMoments(Nd4jBackend backend) { for (int[] axes : new int[][]{{0}, {1}, {0, 1}}) { INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4); @@ -787,9 +799,11 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testMomentsOp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMomentsOp(Nd4jBackend backend) { int[] axes = new int[]{0}; - INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4); + INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray outMean = Nd4j.createUninitialized(new long[]{4}); INDArray outVar = Nd4j.createUninitialized(new long[]{4}); @@ -804,7 +818,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testNormalizeMomentsOp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNormalizeMomentsOp(Nd4jBackend backend) { INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10); INDArray ssSum = data.sum(0); INDArray ssSqSum = data.mul(data).sum(0); @@ -824,7 +840,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testAllAny() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllAny(Nd4jBackend backend) { INDArray allZeros = Nd4j.zeros(DataType.FLOAT, 3, 4); INDArray allOnes = Nd4j.ones(DataType.FLOAT, 3, 4); @@ -852,7 +870,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testIndexAccum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIndexAccum(Nd4jBackend backend) { List failed = new ArrayList<>(); List dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/); @@ -941,7 +961,9 @@ public class ReductionOpValidation extends BaseOpValidation { @Test - public void testReduce3_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduce3_2(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int d0 = 3; @@ -1039,7 +1061,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testReductionsBackwards() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReductionsBackwards(Nd4jBackend backend) { // for (int i = 0; i < 7; i++) { int i=5; { @@ -1108,6 +1132,8 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDotProductAttention(){ final INDArray keys = Nd4j.rand(new int[]{10, 4, 3}); final INDArray values = Nd4j.rand(new int[]{10, 4, 3}); @@ -1127,12 +1153,14 @@ public class ReductionOpValidation extends BaseOpValidation { t.norm1("out"); String err = OpValidation.validate(new TestCase(sd) - .expectedOutput("out", finalOut) - .gradientCheck(true)); + .expectedOutput("out", finalOut) + .gradientCheck(true)); assertNull(err); } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDotProductAttentionWithMask(){ final INDArray keys = Nd4j.rand(new int[]{10, 4, 3}); final INDArray values = Nd4j.rand(new int[]{10, 4, 3}); @@ -1163,6 +1191,8 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDotProductAttentionMultiHeadInputWithMask(){ final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3}); @@ -1194,6 +1224,8 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDotProductAttentionMultiHeadInput(){ final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3}); final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3}); @@ -1221,6 +1253,8 @@ public class ReductionOpValidation extends BaseOpValidation { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMultiHeadedDotProductAttention(){ final INDArray k = Nd4j.rand(new int[]{10, 4, 5}); final INDArray v = Nd4j.rand(new int[]{10, 4, 5}); @@ -1272,6 +1306,8 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDotProductAttentionWeirdInputs(){ final INDArray keys = Nd4j.rand(new int[]{10, 4, 3}); final INDArray values = Nd4j.rand(new int[]{10, 4, 3}); @@ -1309,6 +1345,8 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMultiHeadedDotProductAttentionWeirdInputs(){ final INDArray k = Nd4j.rand(new int[]{10, 4, 5}); final INDArray v = Nd4j.rand(new int[]{10, 4, 5}); @@ -1366,7 +1404,9 @@ public class ReductionOpValidation extends BaseOpValidation { } } @Test - public void testSufficientStatisticsOp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSufficientStatisticsOp(Nd4jBackend backend) { INDArray data = Nd4j.createFromArray(new double[]{ 5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1., 1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5 @@ -1392,7 +1432,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testStandardDeviation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStandardDeviation(Nd4jBackend backend) { for (boolean keepDims : new boolean[]{false, true}) { SameDiff sameDiff = SameDiff.create(); @@ -1419,7 +1461,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testSquaredNorm() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSquaredNorm(Nd4jBackend backend) { for (boolean keepDims : new boolean[]{false, true}) { SameDiff sameDiff = SameDiff.create(); @@ -1442,7 +1486,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testShannonEntropy() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShannonEntropy(Nd4jBackend backend) { OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695 SameDiff sameDiff = SameDiff.create(); @@ -1462,7 +1508,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testEntropy() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEntropy(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); @@ -1481,7 +1529,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testAMean() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAMean(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); @@ -1502,7 +1552,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testMean() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMean(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); @@ -1523,7 +1575,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testNorm1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm1(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); @@ -1544,7 +1598,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testNorm2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm2(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); @@ -1565,7 +1621,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testNormMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNormMax(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); @@ -1586,7 +1644,9 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - public void testSoftmaxCrossEntropyWithLogitsLoss() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftmaxCrossEntropyWithLogitsLoss(Nd4jBackend backend) { OpValidationSuite.ignoreFailing(); SameDiff sameDiff = SameDiff.create(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java index 53ea7d095..3a4ef608e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java @@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; @@ -43,12 +45,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class RnnOpValidation extends BaseOpValidation { - public RnnOpValidation(Nd4jBackend backend) { - super(backend); - } @Test - public void testRnnBlockCell(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRnnBlockCell(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int mb = 2; int nIn = 3; @@ -147,7 +148,9 @@ public class RnnOpValidation extends BaseOpValidation { @Test - public void testRnnBlockCellManualTFCompare() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRnnBlockCellManualTFCompare(Nd4jBackend backend) { //Test case: "rnn/lstmblockcell/static_batch1_n3-2_tsLength1_noPH_noClip_fBias1_noIS" SameDiff sd = SameDiff.create(); @@ -209,6 +212,8 @@ public class RnnOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGRUCell(){ Nd4j.getRandom().setSeed(12345); int mb = 2; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 46e03f3e3..38080a906 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -28,6 +28,8 @@ import lombok.val; import org.apache.commons.math3.linear.LUDecomposition; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -67,9 +69,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*; @Slf4j public class ShapeOpValidation extends BaseOpValidation { - public ShapeOpValidation(Nd4jBackend backend) { - super(backend); - } /* To test: @@ -83,7 +82,9 @@ public class ShapeOpValidation extends BaseOpValidation { */ @Test - public void testConcat() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcat(Nd4jBackend backend) { // int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2}; int[] concatDim = new int[]{0, 0, 0}; List> origShapes = new ArrayList<>(); @@ -123,7 +124,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testReshapeGradient() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReshapeGradient(Nd4jBackend backend) { //https://github.com/deeplearning4j/deeplearning4j/issues/6873 int[] origShape = new int[]{3, 4, 5}; @@ -159,7 +162,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testPermuteGradient() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPermuteGradient(Nd4jBackend backend) { int[] origShape = new int[]{3, 4, 5}; List failed = new ArrayList<>(); @@ -197,6 +202,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRank(){ List inShape = Arrays.asList(null, new long[]{1}, new long[]{6}, new long[]{3,4}, new long[]{3,4,5}); @@ -224,7 +231,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testExpandDimsGradient() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExpandDimsGradient(Nd4jBackend backend) { val origShape = new long[]{3, 4}; List failed = new ArrayList<>(); @@ -280,7 +289,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testSqueezeGradient() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSqueezeGradient(Nd4jBackend backend) { val origShape = new long[]{3, 4, 5}; List failed = new ArrayList<>(); @@ -344,7 +355,9 @@ public class ShapeOpValidation extends BaseOpValidation { @Test - public void testSliceGradient() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSliceGradient(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); //Order here: original shape, begin, size @@ -434,7 +447,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testStridedSliceGradient() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStridedSliceGradient(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); //Order here: original shape, begin, size @@ -497,7 +512,9 @@ public class ShapeOpValidation extends BaseOpValidation { @Test - public void testMerge() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMerge(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); List failed = new ArrayList<>(); @@ -573,7 +590,7 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test() - public void testStack() { + public void testStack(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); List failed = new ArrayList<>(); @@ -664,7 +681,9 @@ public class ShapeOpValidation extends BaseOpValidation { @Test - public void testUnStack() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUnStack(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); List failed = new ArrayList<>(); @@ -752,7 +771,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testTile() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTile(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); List tileArg = Arrays.asList( @@ -824,6 +845,8 @@ public class ShapeOpValidation extends BaseOpValidation { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTileBp(){ Nd4j.getRandom().setSeed(12345); @@ -857,6 +880,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTileBp2(){ Nd4j.getRandom().setSeed(12345); @@ -891,7 +916,9 @@ public class ShapeOpValidation extends BaseOpValidation { @Test - public void testReshape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReshape(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr = Transforms.sigmoid(Nd4j.linspace(-5, 6, 12)).reshape(3, 4); SDVariable x = sameDiff.var("x", arr); @@ -907,7 +934,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testReshape2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReshape2(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int[] origShape = new int[]{3, 4, 5}; @@ -930,7 +959,9 @@ public class ShapeOpValidation extends BaseOpValidation { @Test - public void testTranspose() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTranspose(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4); SDVariable x = sameDiff.var("x", arr); @@ -942,6 +973,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTransposeOp(){ INDArray arr = Nd4j.linspace(1,15, 15).reshape(5,3); @@ -955,7 +988,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShape(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); val shape = new long[]{2, 3}; SDVariable x = sameDiff.var("x", shape); @@ -970,7 +1005,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testSize() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSize(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); val shape = new long[]{2, 3}; SDVariable x = sameDiff.var("x", DataType.FLOAT, shape); @@ -984,7 +1021,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testDiagShapeFn() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDiagShapeFn(Nd4jBackend backend) { INDArray i = Nd4j.linspace(1, 16, 16).reshape(4,4); OpTestCase op = new OpTestCase(new DiagPart(i, null)); @@ -998,6 +1037,8 @@ public class ShapeOpValidation extends BaseOpValidation { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPermute(){ INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5); INDArray exp = in.permute(0,1,2); //No op @@ -1012,6 +1053,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPermute2(){ for (int[] perm : new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}) { INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5); @@ -1032,6 +1075,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testConstant(){ //OpValidationSuite.ignoreFailing(); @@ -1059,6 +1104,8 @@ public class ShapeOpValidation extends BaseOpValidation { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testUnstackEdgeCase2(){ for( int i=0; i<3; i++ ) { @@ -1073,7 +1120,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void invertPermutation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void invertPermutation(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); INDArray ia = Nd4j.create(new float[] {3, 4, 0, 2, 1}).castTo(DataType.INT); @@ -1090,6 +1139,8 @@ public class ShapeOpValidation extends BaseOpValidation { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGatherNd(){ List indices = new ArrayList<>(); @@ -1128,7 +1179,9 @@ public class ShapeOpValidation extends BaseOpValidation { @Test - public void testReverseSequence() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverseSequence(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); float[] input_data = new float[]{ 1, 2, 3, @@ -1174,6 +1227,8 @@ public class ShapeOpValidation extends BaseOpValidation { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMatrixDeterminant(){ OpValidationSuite.ignoreFailing(); //Gradient check failing @@ -1195,6 +1250,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDeterminant22(){ OpValidationSuite.ignoreFailing(); //Gradient check failing @@ -1219,6 +1276,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMatrixDeterminant3(){ OpValidationSuite.ignoreFailing(); //Gradient checks failing Nd4j.getRandom().setSeed(12345); @@ -1250,6 +1309,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMatrixDeterminant4(){ OpValidationSuite.ignoreFailing(); //Gradient checks failing Nd4j.getRandom().setSeed(12345); @@ -1270,6 +1331,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSegmentOps(){ OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6952 @@ -1362,6 +1425,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSegmentMean(){ INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3); INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2); @@ -1382,7 +1447,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testSequenceMask() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSequenceMask(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2}); // arr is not trainable, so it's constant in model @@ -1391,10 +1458,10 @@ public class ShapeOpValidation extends BaseOpValidation { // Test with static max len int maxlen = 5; INDArray expected = Nd4j.create(new float[] { - 1.f, 0.f, 0.f, 0.f, 0.f, - 1.f, 1.f, 1.f, 0.f, 0.f, - 1.f, 1.f, 0.f, 0.f, 0.f - }).reshape(3,5); + 1.f, 0.f, 0.f, 0.f, 0.f, + 1.f, 1.f, 1.f, 0.f, 0.f, + 1.f, 1.f, 0.f, 0.f, 0.f + }).reshape(3,5); INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.FLOAT)); SDVariable result1 = sameDiff.sequenceMask(lengths, maxlen, DataType.FLOAT); assertArrayEquals(expected.shape(), result1.eval().shape()); @@ -1416,6 +1483,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMeshGrid(){ List failed = new ArrayList<>(); @@ -1472,6 +1541,8 @@ public class ShapeOpValidation extends BaseOpValidation { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGather(){ List inArrs = new ArrayList<>(); List axis = new ArrayList<>(); @@ -1541,7 +1612,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testGatherSimple() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGatherSimple(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr = Nd4j.create(new float[]{1, 2, 3, 4}, new long[]{2, 2}); SDVariable x = sameDiff.var("x", arr); @@ -1551,7 +1624,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testGatherNdSingle() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGatherNdSingle(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(DataType.DOUBLE, 1, 24, 24)).reshape(2, 3, 4); INDArray arr2 = Nd4j.create(new float[]{1, 2, 3, 0, 1, 3, 1, 0, 2}, new long[]{3, 3}).castTo(DataType.INT); @@ -1563,14 +1638,16 @@ public class ShapeOpValidation extends BaseOpValidation { for (int i=0; i<3; i++){ INDArray idx = arr2.get(point(i), NDArrayIndex.all()); expected.putScalar(i, arr1.get(point(idx.getInt(0)), - point(idx.getInt(1)), - point(idx.getInt(2))).getDouble(0)); + point(idx.getInt(1)), + point(idx.getInt(2))).getDouble(0)); } assertEquals(expected, result.eval()); } @Test - public void testStack2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStack2(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2); INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2); @@ -1581,7 +1658,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testParallelStack() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testParallelStack(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2); INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2); @@ -1593,7 +1672,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testUnStack2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUnStack2(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr1 = Nd4j.zeros(3, 2); INDArray arr2 = Nd4j.ones(3, 2); @@ -1606,7 +1687,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testPermuteSimple() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPermuteSimple(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3)); SDVariable x = sameDiff.var("x", arr); @@ -1617,7 +1700,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testConcat2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcat2(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4); INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(4, 8, 4)).reshape(1,4); @@ -1628,7 +1713,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testTile2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTile2(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1,4)); SDVariable x = sameDiff.var("x", arr); @@ -1641,7 +1728,9 @@ public class ShapeOpValidation extends BaseOpValidation { @Test - public void testSlice2d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSlice2d(Nd4jBackend backend) { INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); SameDiff sd = SameDiff.create(); @@ -1657,7 +1746,9 @@ public class ShapeOpValidation extends BaseOpValidation { @Test - public void testSlice3d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSlice3d(Nd4jBackend backend) { INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); SameDiff sd = SameDiff.create(); @@ -1672,7 +1763,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testStridedSlice2dBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStridedSlice2dBasic(Nd4jBackend backend) { INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); SameDiff sd = SameDiff.create(); @@ -1690,7 +1783,9 @@ public class ShapeOpValidation extends BaseOpValidation { @Test - public void testStridedSliceBeginEndMask() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStridedSliceBeginEndMask(Nd4jBackend backend) { INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); SameDiff sd = SameDiff.create(); @@ -1705,7 +1800,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testStridedSliceEllipsisMask() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStridedSliceEllipsisMask(Nd4jBackend backend) { INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); @@ -1722,7 +1819,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testStridedSliceNewAxisMask() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStridedSliceNewAxisMask(Nd4jBackend backend) { INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); @@ -1735,7 +1834,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testStridedSliceNewAxisMask2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStridedSliceNewAxisMask2(Nd4jBackend backend) { INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); @@ -1746,7 +1847,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testStridedSliceShrinkAxisMask() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStridedSliceShrinkAxisMask(Nd4jBackend backend) { INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5); SameDiff sd = SameDiff.create(); @@ -1763,7 +1866,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testSizeAt_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSizeAt_1(Nd4jBackend backend) { val array = Nd4j.create(10, 20, 30); val exp = Nd4j.scalar(DataType.LONG, 20); @@ -1777,6 +1882,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testEye(){ int[] rows = new int[]{3,3,3,3}; int[] cols = new int[]{3,2,2,2}; @@ -1815,6 +1922,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSplit1(){ INDArray in = Nd4j.linspace(1,10,10).reshape(10); INDArray axis = Nd4j.scalar(-1); @@ -1833,6 +1942,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSplit2(){ INDArray in = Nd4j.linspace(1,24,24).reshape(3,8); INDArray axis = Nd4j.scalar(-1); @@ -1851,6 +1962,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDistancesExec(){ //https://github.com/deeplearning4j/deeplearning4j/issues/7001 for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) { @@ -1906,6 +2019,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testReductionShape(){ INDArray shape = Nd4j.createFromArray(4,2); @@ -1924,6 +2039,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void gatherTest(){ INDArray in = Nd4j.createFromArray(new double[][]{ {1,2,3,4,5}, @@ -1943,6 +2060,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSliceShape(){ INDArray arr = Nd4j.arange(0, 25).reshape(1,5,5).castTo(DataType.INT); @@ -1964,6 +2083,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWhereAllFalse(){ INDArray in = Nd4j.create(DataType.BOOL, 1917); DynamicCustomOp op = DynamicCustomOp.builder("Where") @@ -1978,6 +2099,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGatherScalar(){ INDArray in = Nd4j.linspace(100, 200, 100, DataType.FLOAT).reshape(100); INDArray indices = Nd4j.scalar(0); @@ -2002,6 +2125,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCastEmpty(){ INDArray emptyLong = Nd4j.empty(DataType.LONG); int dtype = 9; //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h @@ -2018,6 +2143,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGatherEmpty(){ /* tf.reset_default_graph() @@ -2050,6 +2177,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSplitEmpty(){ /* tf.reset_default_graph() @@ -2087,6 +2216,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testConcatEmpty(){ /* TF behaviour with concatenatioun of empty arrays: @@ -2136,6 +2267,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testConcatEmpty2(){ INDArray empty10a = Nd4j.create(DataType.INT, 1, 0); INDArray empty10b = Nd4j.create(DataType.INT, 1, 0); @@ -2168,6 +2301,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testEmptyGather(){ /* tf.reset_default_graph() @@ -2200,6 +2335,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBroadcastDynamicShape1(){ //Test case: [2,1] and [4]: expect [2,4] @@ -2221,6 +2358,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBroadcastDynamicShape2(){ //Test case: [2,1,4] and [2,2,4]: expect [2,2,4] @@ -2243,6 +2382,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testStridedSliceShrinkAxis(){ INDArray in = Nd4j.create(DataType.DOUBLE, 3,2,2); INDArray begin = Nd4j.createFromArray(2); @@ -2268,6 +2409,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testStridedSliceEmpty(){ INDArray in = Nd4j.createFromArray(10); //Integer, Length 1, rank 1, value 10 - Not used due to begin mask! @@ -2290,6 +2433,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testStridedSliceEdgeCase(){ INDArray in = Nd4j.scalar(10).reshape(1); //Int [1] INDArray begin = Nd4j.ones(DataType.INT, 1); @@ -2315,6 +2460,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testEmptySlice1(){ INDArray in = Nd4j.createFromArray(38); INDArray begin = Nd4j.createFromArray(1); @@ -2334,6 +2481,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testEmptySlice2(){ INDArray in = Nd4j.createFromArray(38); INDArray begin = Nd4j.createFromArray(0); @@ -2353,6 +2502,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testFill(){ INDArray shape = Nd4j.createFromArray(0,4); @@ -2372,6 +2523,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testFill2(){ INDArray shape = Nd4j.createFromArray(0,4); @@ -2389,6 +2542,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPermuteShapeDynamicAxis(){ DynamicCustomOp op = DynamicCustomOp.builder("permute") @@ -2418,6 +2573,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGather2(){ SameDiff sd = SameDiff.create(); SDVariable input = sd.var("in", Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3)); @@ -2437,6 +2594,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPermute3(){ INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2); INDArray permute = Nd4j.createFromArray(1,0); @@ -2455,6 +2614,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPermute4(){ INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2); INDArray permute = Nd4j.createFromArray(1,0); @@ -2485,6 +2646,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testInvertPermutation(){ DynamicCustomOp op = DynamicCustomOp.builder("invert_permutation") .addInputs(Nd4j.createFromArray(1, 0)) @@ -2492,7 +2655,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testBroadcastInt1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcastInt1(Nd4jBackend backend) { INDArray out = Nd4j.create(DataType.INT, 1); DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape") @@ -2505,6 +2670,8 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBroadcastInt2(){ INDArray out = Nd4j.create(DataType.INT, 2); DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape") @@ -2544,7 +2711,9 @@ public class ShapeOpValidation extends BaseOpValidation { @Test - public void testMergeMaxIndex() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergeMaxIndex(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -2561,7 +2730,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testTriOp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTriOp(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable out = new Tri(sd, DataType.INT32, 3, 5, 2).outputVariable(); @@ -2573,7 +2744,9 @@ public class ShapeOpValidation extends BaseOpValidation { } @Test - public void testTriuOp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTriuOp(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable input = sd.var(Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {7,8,9},{10,11,12}})); @@ -2581,8 +2754,8 @@ public class ShapeOpValidation extends BaseOpValidation { out.markAsLoss(); INDArray expected = Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {0,8,9},{0,0,12}}); String err = OpValidation.validate(new TestCase(sd) - .expectedOutput("triu", expected) - .gradientCheck(true)); + .expectedOutput("triu", expected) + .gradientCheck(true)); assertNull(err); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index 70e263740..c1464063c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -26,6 +26,8 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; @@ -94,9 +96,6 @@ public class TransformOpValidation extends BaseOpValidation { private DataType initialType; - public TransformOpValidation(Nd4jBackend backend) { - super(backend); - } @BeforeEach public void before() { @@ -120,7 +119,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testScalarOps() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarOps(Nd4jBackend backend) { int d0 = 2; int d1 = 3; int d2 = 4; @@ -217,7 +218,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testScalarMulCF() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarMulCF(Nd4jBackend backend) { INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); INDArray outC = Nd4j.createUninitialized(3, 4); @@ -231,7 +234,9 @@ public class TransformOpValidation extends BaseOpValidation { @Test - public void testScalarMulCF2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarMulCF2(Nd4jBackend backend) { INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); @@ -242,7 +247,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testCross() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCross(Nd4jBackend backend) { INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3}); INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3}); @@ -270,7 +277,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testSpaceToDepth() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSpaceToDepth(Nd4jBackend backend) { Nd4j.getRandom().setSeed(1337); int miniBatch = 128; @@ -298,7 +307,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testDepthToSpace() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDepthToSpace(Nd4jBackend backend) { Nd4j.getRandom().setSeed(1337); int miniBatch = 128; @@ -325,7 +336,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testBatchToSpace() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBatchToSpace(Nd4jBackend backend) { //OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863 Nd4j.getRandom().setSeed(1337); @@ -362,7 +375,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testSpaceToBatch() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSpaceToBatch(Nd4jBackend backend) { //OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863 Nd4j.getRandom().setSeed(7331); @@ -400,7 +415,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testDynamicPartition() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDynamicPartition(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); INDArray ia = Nd4j.create(new double[]{4, 3, 5, 7, 8, 0}); @@ -440,7 +457,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testDynamicPartition2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDynamicPartition2(Nd4jBackend backend) { INDArray data = Nd4j.createFromArray(2, 1, 2, 0); INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition") @@ -458,7 +477,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testDynamicStitch() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDynamicStitch(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); INDArray ia = Nd4j.create(new double[]{5, 1, 3}, new long[]{3}); @@ -495,7 +516,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testDiag() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDiag(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); INDArray ia = Nd4j.create(new double[]{1, 2}, new int[]{2}); @@ -521,7 +544,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testDiagPart() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDiagPart(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); INDArray input = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4); @@ -540,7 +565,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testEye() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEye(Nd4jBackend backend) { int[] rows = new int[]{3, 3, 3, 3}; int[] cols = new int[]{3, 2, 2, 2}; int[][] batch = new int[][]{{}, {}, {4}, {3, 3}}; @@ -574,7 +601,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testEyeShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEyeShape(Nd4jBackend backend) { DynamicCustomOp dco = DynamicCustomOp.builder("eye") .addIntegerArguments(3, 3) //.addIntegerArguments(-99,3,3) //Also fails @@ -586,7 +615,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testTransforms() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTransforms(Nd4jBackend backend) { //Test transforms (non-pairwise) Nd4j.getRandom().setSeed(12345); @@ -1074,7 +1105,9 @@ public class TransformOpValidation extends BaseOpValidation { @Test - public void testPairwiseTransforms() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPairwiseTransforms(Nd4jBackend backend) { /* add, sub, mul, div, rsub, rdiv eq, neq, gt, lt, gte, lte, or, and, xor @@ -1258,7 +1291,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testIsX() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsX(Nd4jBackend backend) { List failed = new ArrayList<>(); for (int i = 0; i < 4; i++) { @@ -1313,7 +1348,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testReplaceWhereScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReplaceWhereScalar(Nd4jBackend backend) { for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) { log.info("Testing condition: " + c.getClass().getSimpleName()); @@ -1335,7 +1372,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testReplaceWhereArray() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReplaceWhereArray(Nd4jBackend backend) { for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) { INDArray inArr = Nd4j.rand(3, 4); @@ -1358,7 +1397,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testLogGrad() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLogGrad(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); SDVariable input = sameDiff.var("x", Nd4j.linspace(1, 4, 4, DataType.DOUBLE)); SDVariable log = sameDiff.math().log(input); @@ -1369,7 +1410,9 @@ public class TransformOpValidation extends BaseOpValidation { @Test - public void testSigmoidBackwards() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSigmoidBackwards(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); Map inputs = new HashMap<>(); @@ -1386,8 +1429,10 @@ public class TransformOpValidation extends BaseOpValidation { } -/* @Test - public void testDepth() { +/* @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDepth(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); SDVariable x = sameDiff.one("one",new long[]{2,2}); assertEquals(0,x.depth()); @@ -1396,7 +1441,9 @@ public class TransformOpValidation extends BaseOpValidation { }*/ @Test - public void testRank0EdgeCase() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRank0EdgeCase(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))); double d0 = v1.eval().getDouble(0); @@ -1409,7 +1456,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testAtan2BroadcastShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAtan2BroadcastShape(Nd4jBackend backend) { INDArray arr1 = Nd4j.create(new long[]{3, 1, 4}); INDArray arr2 = Nd4j.create(new long[]{1, 2, 4}); @@ -1424,7 +1473,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testBooleanAnd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBooleanAnd(Nd4jBackend backend) { Nd4j.setDataType(DataType.FLOAT); INDArray arr1 = Nd4j.create(new long[]{3, 4}); INDArray arr2 = Nd4j.create(new long[]{3, 4}); @@ -1438,7 +1489,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testScatterOpsScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterOpsScalar(Nd4jBackend backend) { for (String s : new String[]{"add", "sub", "mul", "div"}) { INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3); INDArray indices = Nd4j.scalar(5); @@ -1483,7 +1536,9 @@ public class TransformOpValidation extends BaseOpValidation { @Disabled("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540") @Test - public void testPad() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPad(Nd4jBackend backend) { INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0); INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG); INDArray value = Nd4j.scalar(10.0); @@ -1510,7 +1565,9 @@ public class TransformOpValidation extends BaseOpValidation { @Test - public void testMirrorPad() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMirrorPad(Nd4jBackend backend) { INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT); @@ -1543,7 +1600,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testMirrorPad2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMirrorPad2(Nd4jBackend backend) { INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT); @@ -1569,7 +1628,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testMirrorPadSymmetric() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMirrorPadSymmetric(Nd4jBackend backend) { INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4); INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT); @@ -1596,7 +1657,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testUnique() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUnique(Nd4jBackend backend) { INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4}); INDArray expUnique = Nd4j.create(new double[]{3, 4, 1, 0, 2}); @@ -1618,7 +1681,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testTopK() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTopK(Nd4jBackend backend) { OpValidationSuite.ignoreFailing(); //Can't assume sorted here INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8}); @@ -1647,7 +1712,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testTopK1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTopK1(Nd4jBackend backend) { INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0); INDArray k = Nd4j.scalar(1); INDArray outValue = Nd4j.create(DataType.DOUBLE, 1); @@ -1668,7 +1735,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testInTopK() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInTopK(Nd4jBackend backend) { for (int k = 4; k >= 1; k--) { log.info("Testing: k=" + k); INDArray in = Nd4j.linspace(1, 20, 20, DataType.DOUBLE).reshape(4, 5); @@ -1709,7 +1778,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testZeta() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testZeta(Nd4jBackend backend) { OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182 INDArray x = Nd4j.rand(3, 4).addi(1.0); INDArray q = Nd4j.rand(3, 4); @@ -1726,7 +1797,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testMaxEmptyScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMaxEmptyScalar(Nd4jBackend backend) { INDArray empty = Nd4j.empty(DataType.FLOAT); INDArray scalar = Nd4j.scalar(1.0f); @@ -1743,7 +1816,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testBroadcastEmpty() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcastEmpty(Nd4jBackend backend) { // Nd4j.getExecutioner().enableVerboseMode(true); // Nd4j.getExecutioner().enableDebugMode(true); //Check broadcast behaviour with empty arrays. The idea is to match TF import behaviour, for import @@ -1833,7 +1908,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testStandardize() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStandardize(Nd4jBackend backend) { final INDArray random = Nd4j.rand(new int[]{10, 4}); final int[] axis = new int[]{1}; @@ -1854,7 +1931,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testStandardizeOP() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStandardizeOP(Nd4jBackend backend) { final INDArray random = Nd4j.rand(new int[]{10, 4}); final int[] axis = new int[]{1}; @@ -1869,7 +1948,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testStandardizeNoDeviation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStandardizeNoDeviation(Nd4jBackend backend) { final INDArray random = Nd4j.rand(new int[]{10, 4}); for (int i = 0; i < 4; i++) { random.putScalar(1, i, 7); @@ -1895,7 +1976,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testMatMulTensor() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatMulTensor(Nd4jBackend backend) { final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5}); final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6}); @@ -1915,7 +1998,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testMatMulTensorTranspose() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatMulTensorTranspose(Nd4jBackend backend) { for (boolean transposeA : new boolean[]{false, true}) { for (boolean transposeB : new boolean[]{false, true}) { for (boolean transposeResult : new boolean[]{false, true}) { @@ -2008,7 +2093,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testSoftmaxCF() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftmaxCF(Nd4jBackend backend) { INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5); INDArray arrF = arrC.dup('f'); @@ -2029,7 +2116,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testLogSumExp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLogSumExp(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4); SameDiff sd = SameDiff.create(); @@ -2044,7 +2133,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testLogSumExp2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLogSumExp2(Nd4jBackend backend) { for (int dim = 0; dim <= 2; dim++) { Nd4j.getRandom().setSeed(12345); @@ -2065,7 +2156,9 @@ public class TransformOpValidation extends BaseOpValidation { @Test - public void testCRELU() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCRELU(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2); @@ -2084,7 +2177,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testClipByAvgNorm() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testClipByAvgNorm(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2, 2); @@ -2105,7 +2200,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testEmbeddingLookup() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmbeddingLookup(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -2118,49 +2215,53 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testImageResize() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testImageResize(Nd4jBackend backend) { //TODO: Methods failed ResizeLanczos5, ResizeMitchelcubic, ResizeArea for (ImageResizeMethod method : ImageResizeMethod.values()) { - if (method==ImageResizeMethod.ResizeLanczos5 || method==ImageResizeMethod.ResizeArea || method == ImageResizeMethod.ResizeMitchelcubic) - {continue;} + if (method==ImageResizeMethod.ResizeLanczos5 || method==ImageResizeMethod.ResizeArea || method == ImageResizeMethod.ResizeMitchelcubic) + {continue;} - log.info("Trying {}", method); + log.info("Trying {}", method); - Nd4j.getRandom().setSeed(12345); - SameDiff sd = SameDiff.create(); - boolean preserveAspectRatio = true; - boolean antialias = true; - SDVariable inputImage = sd.var(Nd4j.rand(DataType.FLOAT, 1, 5, 5, 3)); - // NHWC format - long[] expectedShape = new long[]{1, 3, 3, 3}; - SDVariable requestedSize = sd.constant(Nd4j.createFromArray( new long[]{3, 3})); + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + boolean preserveAspectRatio = true; + boolean antialias = true; + SDVariable inputImage = sd.var(Nd4j.rand(DataType.FLOAT, 1, 5, 5, 3)); + // NHWC format + long[] expectedShape = new long[]{1, 3, 3, 3}; + SDVariable requestedSize = sd.constant(Nd4j.createFromArray( new long[]{3, 3})); - Function checkFunction = in -> { - boolean shapeOk = Arrays.equals(expectedShape, in.shape()); - if (shapeOk) return null; - return "Failed: shape differs - expected " + Arrays.toString(expectedShape) + " vs " + Arrays.toString(in.shape()) + " on method " + method; - }; + Function checkFunction = in -> { + boolean shapeOk = Arrays.equals(expectedShape, in.shape()); + if (shapeOk) return null; + return "Failed: shape differs - expected " + Arrays.toString(expectedShape) + " vs " + Arrays.toString(in.shape()) + " on method " + method; + }; - SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable().std(true); + SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable().std(true); - String err = OpValidation.validate(new TestCase(sd) - .gradientCheck(false) - .expected("image_resize", checkFunction)); + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(false) + .expected("image_resize", checkFunction)); assertNull(err); } - } + } @Test - public void testMaximumBp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMaximumBp(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -2177,7 +2278,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testMergeAddBp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergeAddBp(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -2194,7 +2297,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testMergeMaxBp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergeMaxBp(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -2212,7 +2317,9 @@ public class TransformOpValidation extends BaseOpValidation { @Test - public void testMergeAvgBp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergeAvgBp(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -2229,7 +2336,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testReverseBp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverseBp(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -2243,7 +2352,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testUpsampling3dBp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUpsampling3dBp(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); for (boolean dataformat : new boolean[]{true, false}) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java index 108fd4ab6..19e16b00e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java @@ -24,8 +24,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; -import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; @@ -36,11 +37,8 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.factory.Nd4jBackend; -public class ConvConfigTests extends BaseNd4jTest { +public class ConvConfigTests extends BaseNd4jTestWithBackends { - public ConvConfigTests(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -48,7 +46,9 @@ public class ConvConfigTests extends BaseNd4jTest { } @Test - public void testDeConv2D(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDeConv2D(Nd4jBackend backend){ DeConv2DConfig.builder().kH(2).kW(4).build(); try{ @@ -108,8 +108,10 @@ public class ConvConfigTests extends BaseNd4jTest { } } - @Test - public void testConv2D(){ + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConv2D(Nd4jBackend backend){ Conv2DConfig.builder().kH(2).kW(4).build(); try{ @@ -169,8 +171,10 @@ public class ConvConfigTests extends BaseNd4jTest { } } - @Test - public void testPooling2D(){ + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling2D(Nd4jBackend backend){ Pooling2DConfig.builder().kH(2).kW(4).build(); try{ @@ -230,8 +234,10 @@ public class ConvConfigTests extends BaseNd4jTest { } } - @Test - public void testDeConv3D(){ + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDeConv3D(Nd4jBackend backend){ DeConv3DConfig.builder().kH(2).kW(4).kD(3).build(); try{ @@ -319,8 +325,10 @@ public class ConvConfigTests extends BaseNd4jTest { } } - @Test - public void testConv3D(){ + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConv3D(Nd4jBackend backend){ Conv3DConfig.builder().kH(2).kW(4).kD(3).build(); try{ @@ -410,8 +418,10 @@ public class ConvConfigTests extends BaseNd4jTest { - @Test - public void testPooling3D(){ + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling3D(Nd4jBackend backend){ Pooling3DConfig.builder().kH(2).kW(4).kD(3).build(); try{ @@ -499,7 +509,9 @@ public class ConvConfigTests extends BaseNd4jTest { } } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testConv1D(){ Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java index e1473414a..9b3c3c2e9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java @@ -23,8 +23,10 @@ package org.nd4j.autodiff.samediff; import lombok.val; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.OpValidationSuite; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -40,11 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; @Disabled("AB 2019/05/21 - JVM Crash on ppc64 - Issue #7657") -public class FailingSameDiffTests extends BaseNd4jTest { +public class FailingSameDiffTests extends BaseNd4jTestWithBackends { - public FailingSameDiffTests(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -52,7 +51,9 @@ public class FailingSameDiffTests extends BaseNd4jTest { } @Test - public void testEye(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEye(Nd4jBackend backend){ //OpValidationSuite.ignoreFailing(); INDArray arr = Nd4j.create(new double[]{1, 0, 0, 0, 1, 0}, new int[]{2, 3}); List stack = new ArrayList<>(); @@ -68,7 +69,9 @@ public class FailingSameDiffTests extends BaseNd4jTest { } @Test - public void testEyeShape(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEyeShape(Nd4jBackend backend){ val dco = DynamicCustomOp.builder("eye") .addIntegerArguments(3,3) //.addIntegerArguments(-99,3,3) //Also fails @@ -80,7 +83,9 @@ public class FailingSameDiffTests extends BaseNd4jTest { } @Test - public void testExecutionDifferentShapesTransform(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExecutionDifferentShapesTransform(Nd4jBackend backend){ OpValidationSuite.ignoreFailing(); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4)); @@ -101,7 +106,9 @@ public class FailingSameDiffTests extends BaseNd4jTest { } @Test - public void testDropout() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDropout(Nd4jBackend backend) { OpValidationSuite.ignoreFailing(); SameDiff sd = SameDiff.create(); double p = 0.5; @@ -114,7 +121,9 @@ public class FailingSameDiffTests extends BaseNd4jTest { } @Test - public void testExecutionDifferentShapesDynamicCustom(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExecutionDifferentShapesDynamicCustom(Nd4jBackend backend){ OpValidationSuite.ignoreFailing(); SameDiff sd = SameDiff.create(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java index de2249359..2a2d11ef2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java @@ -26,13 +26,15 @@ import org.apache.commons.io.IOUtils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.graph.FlatConfiguration; import org.nd4j.graph.FlatGraph; import org.nd4j.graph.FlatNode; import org.nd4j.graph.FlatVariable; import org.nd4j.graph.IntPair; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; @@ -70,11 +72,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j -public class FlatBufferSerdeTest extends BaseNd4jTest { +public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends { - public FlatBufferSerdeTest(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -84,7 +83,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { @Test - public void testBasic(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception { SameDiff sd = SameDiff.create(); INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4); SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() ); @@ -139,7 +140,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { } @Test - public void testSimple(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception { for( int i = 0; i < 10; i++ ) { for(boolean execFirst : new boolean[]{false, true}) { log.info("Starting test: i={}, execFirst={}", i, execFirst); @@ -268,7 +271,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { @Test - public void testTrainingSerde(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception { //Ensure 2 things: //1. Training config is serialized/deserialized correctly @@ -352,7 +357,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { @Test - public void pooling3DSerialization(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void pooling3DSerialization(Nd4jBackend backend){ SameDiff sd = SameDiff.create(); SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28); @@ -372,7 +379,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { } @Test - public void pooling3DSerialization2(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void pooling3DSerialization2(Nd4jBackend backend){ SameDiff sd = SameDiff.create(); SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java index 384d6eb22..e804e95c4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java @@ -22,12 +22,14 @@ package org.nd4j.autodiff.samediff; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.samediff.transform.GraphTransformUtil; import org.nd4j.autodiff.samediff.transform.OpPredicate; import org.nd4j.autodiff.samediff.transform.SubGraph; import org.nd4j.autodiff.samediff.transform.SubGraphPredicate; import org.nd4j.autodiff.samediff.transform.SubGraphProcessor; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; @@ -42,11 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j -public class GraphTransformUtilTests extends BaseNd4jTest { +public class GraphTransformUtilTests extends BaseNd4jTestWithBackends { - public GraphTransformUtilTests(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -54,7 +53,9 @@ public class GraphTransformUtilTests extends BaseNd4jTest { } @Test - public void testBasic(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasic(Nd4jBackend backend){ SameDiff sd = SameDiff.create(); SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 32); @@ -93,7 +94,9 @@ public class GraphTransformUtilTests extends BaseNd4jTest { } @Test - public void testSubgraphReplace1(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSubgraphReplace1(Nd4jBackend backend){ SameDiff sd = SameDiff.create(); SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 4); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/MemoryMgrTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/MemoryMgrTest.java index 68d6a0905..cd57673e6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/MemoryMgrTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/MemoryMgrTest.java @@ -21,8 +21,10 @@ package org.nd4j.autodiff.samediff; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,11 +34,8 @@ import java.lang.reflect.Field; import static org.junit.jupiter.api.Assertions.*; -public class MemoryMgrTest extends BaseNd4jTest { +public class MemoryMgrTest extends BaseNd4jTestWithBackends { - public MemoryMgrTest(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -44,7 +43,9 @@ public class MemoryMgrTest extends BaseNd4jTest { } @Test - public void testArrayReuseTooLarge() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArrayReuseTooLarge(Nd4jBackend backend) throws Exception { ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr(); Field f = ArrayCacheMemoryMgr.class.getDeclaredField("maxCacheBytes"); @@ -97,7 +98,7 @@ public class MemoryMgrTest extends BaseNd4jTest { assertEquals(10, mmgr.getLruCacheValues().size()); //now, allocate some values: - for( int i=1; i<=10; i++ ) { + for( int i = 1; i <= 10; i++) { INDArray a1 = mmgr.allocate(true, DataType.FLOAT, 25); assertEquals(1000 - i * 100, mmgr.getCurrentCacheSize()); assertEquals(1000 - i * 100, as.getBytesSum()); @@ -116,10 +117,12 @@ public class MemoryMgrTest extends BaseNd4jTest { } @Test - public void testManyArrays(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testManyArrays(Nd4jBackend backend){ ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr(); - for( int i=0; i<1000; i++ ){ + for( int i = 0; i < 1000; i++) { mmgr.release(Nd4j.scalar(0)); } @@ -127,7 +130,7 @@ public class MemoryMgrTest extends BaseNd4jTest { assertEquals(1000, mmgr.getLruCache().size()); assertEquals(1000, mmgr.getLruCacheValues().size()); - for( int i=0; i<1000; i++ ){ + for( int i = 0; i < 1000; i++ ){ mmgr.release(Nd4j.scalar(0)); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java index 0811f140c..a6af53988 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java @@ -21,9 +21,11 @@ package org.nd4j.autodiff.samediff; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.Variable; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4jBackend; @@ -35,19 +37,18 @@ import java.util.Set; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -public class NameScopeTests extends BaseNd4jTest { +public class NameScopeTests extends BaseNd4jTestWithBackends { - public NameScopeTests(Nd4jBackend b){ - super(b); - } @Override - public char ordering(){ + public char ordering() { return 'c'; } @Test - public void testVariableNameScopesBasic(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariableNameScopesBasic(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable v = sd.var("x"); @@ -73,7 +74,9 @@ public class NameScopeTests extends BaseNd4jTest { } @Test - public void testOpFieldsAndNames(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOpFieldsAndNames(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable x = sd.var("x", DataType.FLOAT, 1); @@ -151,7 +154,9 @@ public class NameScopeTests extends BaseNd4jTest { } @Test - public void testNoNesting(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNoNesting(Nd4jBackend backend) { SameDiff SD = SameDiff.create(); SDVariable a = SD.constant(4); @@ -168,7 +173,9 @@ public class NameScopeTests extends BaseNd4jTest { } @Test - public void testNoTesting2(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNoTesting2(Nd4jBackend backend) { SameDiff SD = SameDiff.create(); SDVariable a = SD.constant(4); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java index fae729e6d..c13229451 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffMultiThreadTests.java @@ -21,21 +21,16 @@ package org.nd4j.autodiff.samediff; import lombok.extern.slf4j.Slf4j; -import org.junit.jupiter.api.Disabled; - import org.junit.jupiter.api.Test; - -import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.common.primitives.AtomicBoolean; import org.nd4j.common.tests.BaseND4JTest; -import org.nd4j.imports.tfgraphs.TFGraphTestZooModels; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.AtomicBoolean; -import org.nd4j.common.resources.Resources; +import org.nd4j.linalg.factory.Nd4jBackend; -import java.io.File; -import java.nio.file.Path; import java.util.Collections; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Semaphore; @@ -55,7 +50,9 @@ public class SameDiffMultiThreadTests extends BaseND4JTest { } @Test - public void testSimple() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSimple(Nd4jBackend backend) throws Exception { int nThreads = 4; int nRuns = 1000; @@ -103,48 +100,6 @@ public class SameDiffMultiThreadTests extends BaseND4JTest { } } - @Test - @Disabled //2020/03/24 AB - https://github.com/eclipse/deeplearning4j/issues/8802 - public void testMobilenet(@TempDir Path testDir) throws Exception { - TFGraphTestZooModels.currentTestDir = testDir.toFile(); - File f = Resources.asFile("tf_graphs/zoo_models/mobilenet_v2_1.0_224/tf_model.txt"); - SameDiff sd = TFGraphTestZooModels.LOADER.apply(f, "mobilenet_v2_1.0_224"); -// System.out.println(sd.summary()); - - int nThreads = 4; - int nRuns = 30; - INDArray[] inputArrs = new INDArray[nThreads]; - INDArray[] expOut = new INDArray[nThreads]; - for( int i=0; i 2) - inputArrs[i] = Nd4j.rand(DataType.FLOAT, 1, 224, 224, 3); - else if(i == 1) - inputArrs[i] = Nd4j.zeros(DataType.FLOAT, 1, 224, 224, 3); - else if(i == 2) - inputArrs[i] = Nd4j.ones(DataType.FLOAT, 1, 224, 224, 3); - - expOut[i] = sd.outputSingle(Collections.singletonMap("input", inputArrs[i]), "MobilenetV2/Predictions/Reshape_1"); - Nd4j.getExecutioner().commit(); - } - - AtomicBoolean[] failuresByThread = new AtomicBoolean[nThreads]; - AtomicInteger[] counters = new AtomicInteger[nThreads]; - Semaphore s = new Semaphore(nThreads); - CountDownLatch latch = new CountDownLatch(nThreads); - - doTest(sd, nThreads, nRuns, inputArrs, expOut, "input", "MobilenetV2/Predictions/Reshape_1", failuresByThread, counters, s, latch); - - s.release(nThreads); - latch.await(); - - for(int i = 0; i < nThreads; i++) { - assertFalse( failuresByThread[i].get(),"Thread " + i + " failed"); - } - - for(int i = 0; i < nThreads; i++) { - assertEquals( nRuns, counters[i].get(),"Thread " + i + " number of runs"); - } - } public static void doTest(SameDiff sd, int nThreads, int nRuns, INDArray[] inputArrs, INDArray[] expOut, diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffOutputTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffOutputTest.java index 90dc1a812..2bd6d0d8c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffOutputTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffOutputTest.java @@ -21,7 +21,9 @@ package org.nd4j.autodiff.samediff; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -31,14 +33,13 @@ import org.nd4j.linalg.learning.config.Sgd; import static org.junit.jupiter.api.Assertions.assertTrue; -public class SameDiffOutputTest extends BaseNd4jTest { +public class SameDiffOutputTest extends BaseNd4jTestWithBackends { - public SameDiffOutputTest(Nd4jBackend backend) { - super(backend); - } @Test - public void outputTest(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void outputTest(Nd4jBackend backend){ DataSet data = new DataSet(Nd4j.zeros(10, 10), Nd4j.zeros(10, 10)); SameDiff sd = SameDiff.create(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java index 30918bb8a..1ca6bbceb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffSpecifiedLossVarsTests.java @@ -21,7 +21,9 @@ package org.nd4j.autodiff.samediff; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -35,19 +37,18 @@ import static junit.framework.TestCase.assertNotNull; import static junit.framework.TestCase.assertNull; import static org.junit.jupiter.api.Assertions.*; -public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest { +public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends { - public SameDiffSpecifiedLossVarsTests(Nd4jBackend b){ - super(b); - } @Override - public char ordering(){ + public char ordering() { return 'c'; } @Test - public void testSpecifiedLoss1(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSpecifiedLoss1(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable ph1 = sd.var("ph", DataType.FLOAT, 3, 4); ph1.setArray(Nd4j.create(DataType.FLOAT, 3, 4)); @@ -68,7 +69,9 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest { } @Test - public void testSpecifiedLoss2(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSpecifiedLoss2(Nd4jBackend backend) { for( int i=0; i<2; i++ ) { SameDiff sd = SameDiff.create(); SDVariable ph = sd.placeHolder("ph", DataType.FLOAT, 3, 4); @@ -121,7 +124,9 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest { @Test - public void testTrainingDifferentLosses(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTrainingDifferentLosses(Nd4jBackend backend) { //Net with 2 losses: train on the first one, then change losses //Also check that if modifying via add/setLossVariables the training config changes diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index a8717e78d..3941b6cea 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -40,6 +40,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.api.OutAndGrad; @@ -55,7 +57,7 @@ import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROCBinary; import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.regression.RegressionEvaluation; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -89,13 +91,10 @@ import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.weightinit.impl.UniformInitScheme; @Slf4j -public class SameDiffTests extends BaseNd4jTest { +public class SameDiffTests extends BaseNd4jTestWithBackends { private DataType initialType; - public SameDiffTests(Nd4jBackend b) { - super(b); - } @Override public char ordering() { @@ -110,7 +109,7 @@ public class SameDiffTests extends BaseNd4jTest { } @BeforeEach - public void before() { + public void before(Nd4jBackend backend) { Nd4j.create(1); initialType = Nd4j.dataType(); @@ -119,7 +118,7 @@ public class SameDiffTests extends BaseNd4jTest { } @AfterEach - public void after() { + public void after(Nd4jBackend backend) { Nd4j.setDataType(initialType); NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); @@ -146,7 +145,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testVariableNaming_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariableNaming_1(Nd4jBackend backend) { val sd = SameDiff.create(); val input = sd.var("inp", new long[]{2, 3}); @@ -163,13 +164,17 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testAddArgsAndOutput() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAddArgsAndOutput(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); val varOne = sameDiff.var("one", Nd4j.ones(2)); } @Test - public void testMseBackwards() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMseBackwards(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); @@ -196,7 +201,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testEvalVariable() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvalVariable(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray ones = Nd4j.ones(4); INDArray twos = ones.add(ones); @@ -207,7 +214,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testSum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.FLOAT)).reshape(1, 4); SDVariable x = sameDiff.var("x", arr); @@ -219,7 +228,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testAddEval() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAddEval(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray x = Nd4j.scalar(1.0); INDArray y = Nd4j.scalar(2.0); @@ -235,7 +246,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testWeightedXentWithLogits() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testWeightedXentWithLogits(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray targets = Nd4j.create(new long[]{1, 5}); INDArray inputs = Nd4j.create(new long[]{1, 5}); @@ -252,7 +265,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testMseForward() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMseForward(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); @@ -278,7 +293,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testDistance() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDistance(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(2, 2); SDVariable x = sameDiff.var("x", arr); @@ -291,7 +308,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testTensorGradMmul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTensorGradMmul(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(2, 2); SDVariable x = sameDiff.var("x", arr); @@ -304,7 +323,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testEval() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEval(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr = Nd4j.linspace(1, 4, 4); SDVariable x = sameDiff.var("x", arr); @@ -315,7 +336,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testFunctionInputsAndArgs() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFunctionInputsAndArgs(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); SDVariable var = sameDiff.var("one", Nd4j.scalar(1.0)); SDVariable variable2 = sameDiff.var("two", Nd4j.scalar(1.0)); @@ -326,7 +349,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testCrossSameDiffVariableInitWithAlloc() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCrossSameDiffVariableInitWithAlloc(Nd4jBackend backend) { SameDiff first = SameDiff.create(); SameDiff second = SameDiff.create(); @@ -338,7 +363,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testCrossSameDiffVariableInitWithPlaceHolder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCrossSameDiffVariableInitWithPlaceHolder(Nd4jBackend backend) { SameDiff first = SameDiff.create(); SameDiff second = SameDiff.create(); @@ -352,7 +379,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testVariableArrayReference() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariableArrayReference(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); SDVariable arr = sameDiff.var("one", new long[]{2, 2}); assertArrayEquals(new long[]{2, 2}, arr.getShape()); @@ -361,7 +390,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testEvalAddSelf() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvalAddSelf(Nd4jBackend backend) { /** * Note this test fails yet due to needing * to validate simple cases like x * x @@ -377,7 +408,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testEvalAdd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvalAdd(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr = Nd4j.linspace(1, 4, 4); INDArray yArr = arr.dup(); @@ -394,7 +427,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testDup() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDup(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 8, 8)).reshape(2, 2, 2); SDVariable x = sameDiff.var("x", arr); @@ -404,29 +439,25 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testElementWiseDivAndRDiv() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testElementWiseDivAndRDiv(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray ones = Nd4j.ones(4); INDArray toDivBy = Nd4j.valueArrayOf(4, 0.25); Map xAndY = new HashMap<>(); xAndY.put("x", ones); xAndY.put("y", toDivBy); - sameDiff.defineFunction("div", new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable x = sameDiff.var("x", inputs.get("x")); - SDVariable y = sameDiff.var("y", inputs.get("y")); - return new SDVariable[]{x.div("out", y)}; - } + sameDiff.defineFunction("div", (sameDiff1, inputs, variableInputs) -> { + SDVariable x = sameDiff1.var("x", inputs.get("x")); + SDVariable y = sameDiff1.var("y", inputs.get("y")); + return new SDVariable[]{x.div("out", y)}; }, xAndY); - sameDiff.defineFunction("rdiv", new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable x = sameDiff.var("x", inputs.get("x")); - SDVariable y = sameDiff.var("y", inputs.get("y")); - return new SDVariable[]{x.rdiv("out", y)}; - } + sameDiff.defineFunction("rdiv", (sameDiff12, inputs, variableInputs) -> { + SDVariable x = sameDiff12.var("x", inputs.get("x")); + SDVariable y = sameDiff12.var("y", inputs.get("y")); + return new SDVariable[]{x.rdiv("out", y)}; }, xAndY); INDArray assertionForDiv = Nd4j.valueArrayOf(4, 4.0); @@ -438,17 +469,16 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testNegativeGradient() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNegativeGradient(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray ones = Nd4j.ones(4); Map xAndY = new HashMap<>(); xAndY.put("x", ones); - sameDiff.defineFunction("neg", new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable x = sameDiff.var("x", inputs.get("x")); - return new SDVariable[]{sameDiff.math().neg("out", x)}; - } + sameDiff.defineFunction("neg", (sameDiff1, inputs, variableInputs) -> { + SDVariable x = sameDiff1.var("x", inputs.get("x")); + return new SDVariable[]{sameDiff1.math().neg("out", x)}; }, xAndY); INDArray assertionForDiv = Nd4j.valueArrayOf(4, -1); @@ -458,18 +488,17 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testSumOp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSumOp(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray sumInput = Nd4j.linspace(1, 4, 4).reshape(2, 2); Map inputs = new HashMap<>(); inputs.put("x", sumInput); - sameDiff.defineFunction("sum", new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable input = sameDiff.var("x", inputs.get("x")); - SDVariable sum = sameDiff.sum("sum", input, 1); - return new SDVariable[]{sum}; - } + sameDiff.defineFunction("sum", (sameDiff1, inputs1, variableInputs) -> { + SDVariable input = sameDiff1.var("x", inputs1.get("x")); + SDVariable sum = sameDiff1.sum("sum", input, 1); + return new SDVariable[]{sum}; }, inputs); INDArray assertion = sumInput.sum(1); @@ -480,7 +509,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testVariableReferenceNoFunction() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariableReferenceNoFunction(Nd4jBackend backend) { /** * Creating a variable should not create a differential function. */ @@ -491,7 +522,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testVariableWithFunction() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariableWithFunction(Nd4jBackend backend) { /** * A variable's function should be null * when just a variable but @@ -507,7 +540,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testUpdateVariable() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUpdateVariable(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); SDVariable one = sameDiff.one("one", new long[]{1, 1}); one.rename("one-diff"); @@ -516,7 +551,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testDefineFunctionArrayExistence() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDefineFunctionArrayExistence(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); String testFunctionName = "testfunction"; SDVariable[] inputVars = new SDVariable[]{ @@ -525,12 +562,7 @@ public class SameDiffTests extends BaseNd4jTest { }; - SameDiff functionDef = sameDiff.defineFunction(testFunctionName, new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - return new SDVariable[]{variableInputs[0].add(variableInputs[1])}; - } - }, inputVars); + SameDiff functionDef = sameDiff.defineFunction(testFunctionName, (sameDiff1, inputs, variableInputs) -> new SDVariable[]{variableInputs[0].add(variableInputs[1])}, inputVars); //1 input plus 2 outputs assertEquals(3, functionDef.variables().size()); @@ -539,7 +571,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testAutoBroadcastAddMatrixVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAutoBroadcastAddMatrixVector(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray row = Nd4j.ones(2); @@ -552,14 +586,18 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testNegativeOneShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNegativeOneShape(Nd4jBackend backend) { val sd = SameDiff.create(); SDVariable var = sd.placeHolder("test", DataType.FLOAT, -1, 3); assertTrue(var.isPlaceHolder()); } @Test - public void testShapeResolutionMinus1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShapeResolutionMinus1(Nd4jBackend backend) { int nIn = 3; int nOut = 4; @@ -603,7 +641,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testLabelInputPlaceHolderSgd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLabelInputPlaceHolderSgd(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); @@ -641,7 +681,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testSequentialMeansPlaceholder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSequentialMeansPlaceholder(Nd4jBackend backend) { OpValidationSuite.ignoreFailing(); for (int dim0 : new int[]{10, -1}) { String msg = "Dimension 0 = " + dim0; @@ -663,7 +705,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testReductionShapes1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReductionShapes1(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", new long[]{10, 9, 8}); @@ -680,7 +724,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testReductionShapes2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReductionShapes2(Nd4jBackend backend) { SameDiff sd2 = SameDiff.create(); SDVariable in2 = sd2.var("in", new long[]{10, 9, 8}); @@ -705,7 +751,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testNames() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNames(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable in1 = sd.var("in", new long[]{3, 2}); SDVariable in2 = sd.var("in2", new long[]{3, 3}); @@ -721,27 +769,26 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testRunLogisticRegression() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRunLogisticRegression(Nd4jBackend backend) { Map vars = this.variablesForInput(); SameDiff outside = SameDiff.create(); - outside.defineFunction("activate", new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - sameDiff.enableDebugMode(); - SDVariable x = sameDiff.var("x", inputs.get("x")); - SDVariable w = sameDiff.var("w", inputs.get("w")); - SDVariable y = sameDiff.var("y", inputs.get("y")); - SDVariable activation = sameDiff.nn().sigmoid("activation", sameDiff.mmul("mmul", x, w)); - SDVariable oneMinusY = y.rsub("oneminusy", 1.0); - SDVariable oneMinusPredictions = activation.rsub("oneminusactivations", 1.0); - SDVariable outputTimesY = y.mul("output * y", activation); - SDVariable yHat = oneMinusPredictions.mul("yhat", oneMinusY); - SDVariable probs = outputTimesY.add("probs", yHat); - SDVariable logProbs = sameDiff.math().log("logprob", probs); - SDVariable ret = sameDiff.sum("totalsum", logProbs, Integer.MAX_VALUE); - SDVariable ret2 = sameDiff.math().neg("negtotalsum", ret); - return new SDVariable[]{ret2}; - } + outside.defineFunction("activate", (sameDiff, inputs, variableInputs) -> { + sameDiff.enableDebugMode(); + SDVariable x = sameDiff.var("x", inputs.get("x")); + SDVariable w = sameDiff.var("w", inputs.get("w")); + SDVariable y = sameDiff.var("y", inputs.get("y")); + SDVariable activation = sameDiff.nn().sigmoid("activation", sameDiff.mmul("mmul", x, w)); + SDVariable oneMinusY = y.rsub("oneminusy", 1.0); + SDVariable oneMinusPredictions = activation.rsub("oneminusactivations", 1.0); + SDVariable outputTimesY = y.mul("output * y", activation); + SDVariable yHat = oneMinusPredictions.mul("yhat", oneMinusY); + SDVariable probs = outputTimesY.add("probs", yHat); + SDVariable logProbs = sameDiff.math().log("logprob", probs); + SDVariable ret = sameDiff.sum("totalsum", logProbs, Integer.MAX_VALUE); + SDVariable ret2 = sameDiff.math().neg("negtotalsum", ret); + return new SDVariable[]{ret2}; }, vars); SameDiff activation = outside.getFunction("activate"); @@ -758,7 +805,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testTransposeWithVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTransposeWithVector(Nd4jBackend backend) { val sd = SameDiff.create(); val matrix = Nd4j.linspace(1, 12, 12).reshape(4, 3); val vector = Nd4j.linspace(1, 4, 4).reshape(4, 1); @@ -770,22 +819,20 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testSimpleDefineFunction() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSimpleDefineFunction(Nd4jBackend backend) { SameDiff sameDiffOuter = SameDiff.create(); Map inputs = variablesForInput(); inputs.remove("y"); String logisticForward = "logisticPredictions"; - sameDiffOuter.defineFunction(logisticForward, new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - - SDVariable input = sameDiff.var("x", inputs.get("x")); - SDVariable w = sameDiff.var("w", inputs.get("w")); - SDVariable preOutput = sameDiff.mmul(input, w); - SDVariable sigmoid = sameDiff.nn().sigmoid(preOutput); - return new SDVariable[]{sigmoid}; - } + sameDiffOuter.defineFunction(logisticForward, (sameDiff, inputs1, variableInputs) -> { + SDVariable input = sameDiff.var("x", inputs1.get("x")); + SDVariable w = sameDiff.var("w", inputs1.get("w")); + SDVariable preOutput = sameDiff.mmul(input, w); + SDVariable sigmoid = sameDiff.nn().sigmoid(preOutput); + return new SDVariable[]{sigmoid}; }, inputs); assertEquals(1, sameDiffOuter.definedFunctionNames().size()); @@ -794,7 +841,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testSumGradient() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSumGradient(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); SDVariable twoByTwo = sameDiff.var("initial", Nd4j.linspace(1, 4, 4, DataType.FLOAT).reshape(2, 2)); SDVariable sum = sameDiff.sum(twoByTwo, Integer.MAX_VALUE); @@ -804,18 +853,17 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testRsubScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRsubScalar(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); Map params = new HashMap<>(); INDArray var = Nd4j.valueArrayOf(4, 2); params.put("x", var); - sameDiff.defineFunction("rsubop", new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable input = sameDiff.var("x", inputs.get("x")); - SDVariable ret = input.rsub("rsub", 1.0); - return new SDVariable[]{ret}; - } + sameDiff.defineFunction("rsubop", (sameDiff1, inputs, variableInputs) -> { + SDVariable input = sameDiff1.var("x", inputs.get("x")); + SDVariable ret = input.rsub("rsub", 1.0); + return new SDVariable[]{ret}; }, params); SameDiff logisticGraph = sameDiff.getFunction("rsubop"); @@ -825,28 +873,24 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testFunctionScalarResultPropagation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFunctionScalarResultPropagation(Nd4jBackend backend) { SameDiff sameDiffOuter = SameDiff.create(); Map inputs = variablesForInput(); - sameDiffOuter.defineFunction("logisticPredictions", new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable input = sameDiff.var("x", inputs.get("x")); - SDVariable w = sameDiff.var("w", inputs.get("w")); - SDVariable preOutput = sameDiff.mmul(input, w); - SDVariable sigmoid = sameDiff.nn().sigmoid(preOutput); - return new SDVariable[]{sigmoid}; - } + sameDiffOuter.defineFunction("logisticPredictions", (sameDiff, inputs12, variableInputs) -> { + SDVariable input = sameDiff.var("x", inputs12.get("x")); + SDVariable w = sameDiff.var("w", inputs12.get("w")); + SDVariable preOutput = sameDiff.mmul(input, w); + SDVariable sigmoid = sameDiff.nn().sigmoid(preOutput); + return new SDVariable[]{sigmoid}; }, inputs); - sameDiffOuter.defineFunction("oneminuspredictions", new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable y = sameDiff.var("y", inputs.get("y")); - SDVariable oneMinusPredictions = y.rsub("rsub", 1.0); - return new SDVariable[]{oneMinusPredictions}; - } + sameDiffOuter.defineFunction("oneminuspredictions", (sameDiff, inputs1, variableInputs) -> { + SDVariable y = sameDiff.var("y", inputs1.get("y")); + SDVariable oneMinusPredictions = y.rsub("rsub", 1.0); + return new SDVariable[]{oneMinusPredictions}; }, inputs); SameDiff logisticGraph = sameDiffOuter.getFunction("oneminuspredictions"); @@ -860,7 +904,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testMmul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmul(Nd4jBackend backend) { SameDiff sameDiffOuter = SameDiff.create(); Map inputs = variablesForInput(); SDVariable x = sameDiffOuter.var("x", inputs.get("x")); @@ -870,32 +916,28 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testGraphBuilding() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGraphBuilding(Nd4jBackend backend) { final SameDiff sameDiffOuter = SameDiff.create(); Map inputs = variablesForInput(); - sameDiffOuter.defineFunction("logisticPredictions", new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable input = sameDiff.var("x", inputs.get("x")); - SDVariable w = sameDiff.var("w", inputs.get("w")); - SDVariable y = sameDiff.var("y", inputs.get("y")); - SDVariable preOutput = sameDiff.mmul(input, w); - SDVariable sigmoid = sameDiff.nn().sigmoid(preOutput); + sameDiffOuter.defineFunction("logisticPredictions", (sameDiff, inputs1, variableInputs) -> { + SDVariable input = sameDiff.var("x", inputs1.get("x")); + SDVariable w = sameDiff.var("w", inputs1.get("w")); + SDVariable y = sameDiff.var("y", inputs1.get("y")); + SDVariable preOutput = sameDiff.mmul(input, w); + SDVariable sigmoid = sameDiff.nn().sigmoid(preOutput); - return new SDVariable[]{sigmoid}; - } + return new SDVariable[]{sigmoid}; }, inputs); - sameDiffOuter.defineFunction("loss", new SameDiffFunctionDefinition() { - @Override - public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable outputs = sameDiffOuter.invokeFunctionOn("logisticPredictions", sameDiff); - SDVariable y = sameDiff.getVariable("y"); - SDVariable outputTimesY = outputs.mul(y); - return new SDVariable[]{outputTimesY}; + sameDiffOuter.defineFunction("loss", (sameDiff, inputs12, variableInputs) -> { + SDVariable outputs = sameDiffOuter.invokeFunctionOn("logisticPredictions", sameDiff); + SDVariable y = sameDiff.getVariable("y"); + SDVariable outputTimesY = outputs.mul(y); + return new SDVariable[]{outputTimesY}; - } }, inputs); SameDiff logisticPrediction = sameDiffOuter.getFunction("logisticPredictions"); @@ -906,7 +948,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testScalarAdd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarAdd(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); SDVariable twoByTwo = sameDiff.var("first", Nd4j.linspace(1, 4, 4).reshape('c', 2, 2)); SDVariable add = twoByTwo.add(1.0); @@ -917,7 +961,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testSums() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSums(Nd4jBackend backend) { SameDiff sameDiff = SameDiff.create(); INDArray ones = Nd4j.ones(7, 4); SDVariable sdVariable = sameDiff.var("ones", ones); @@ -929,7 +975,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testDenseLayerForwardPass() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDenseLayerForwardPass(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -958,7 +1006,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testActivationBackprop() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testActivationBackprop(Nd4jBackend backend) { Activation[] afns = new Activation[]{ Activation.TANH, @@ -1053,7 +1103,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testPlaceholderReduceSimple() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPlaceholderReduceSimple(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable v = sd.var("in", new long[]{-1, 10}); SDVariable vSum = sd.sum(v, 1); //Exception here @@ -1061,7 +1113,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testSequentialMeans() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSequentialMeans(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", new long[]{10, 10, 10}); SDVariable mean1 = sd.mean(in, 2); //[10,10] out @@ -1069,7 +1123,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testBatchNormTest() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBatchNormTest(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); INDArray input = Nd4j.rand(1, 10); @@ -1094,7 +1150,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testLrn() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLrn(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); INDArray input = Nd4j.create(new float[]{4, 4, 4, 4}, new long[]{1, 4, 1, 1}); @@ -1119,7 +1177,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testMoments() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMoments(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); INDArray input = Nd4j.create(new float[]{1, 2, 3, 4}, new long[]{2, 2}); @@ -1143,7 +1203,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testNormalizeMoments() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNormalizeMoments(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); INDArray counts = Nd4j.create(new float[]{2}, new long[]{1, 1}); @@ -1174,7 +1236,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testDepthWiseConv2dBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDepthWiseConv2dBasic(Nd4jBackend backend) { int nIn = 3; int depthWise = 4; int kH = 2; @@ -1212,7 +1276,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void validateMeanDiff() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void validateMeanDiff(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray arr = Nd4j.rand(3, 4); @@ -1234,7 +1300,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void validateSumDiff() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void validateSumDiff(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray arr = Nd4j.rand(3, 4); @@ -1256,7 +1324,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void validateStdevDiff() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void validateStdevDiff(Nd4jBackend backend) { for (boolean biasCorrected : new boolean[]{true, false}) { Nd4j.getRandom().setSeed(12345); @@ -1286,7 +1356,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void validateVarDiff() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void validateVarDiff(Nd4jBackend backend) { for (boolean biasCorrected : new boolean[]{true, false}) { Nd4j.getRandom().setSeed(12345); @@ -1315,7 +1387,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void validateMinDiff() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void validateMinDiff(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray arr = Nd4j.rand(3, 4); @@ -1340,7 +1414,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void validateMaxDiff() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void validateMaxDiff(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray arr = Nd4j.rand(DataType.DOUBLE, 3, 4); @@ -1364,7 +1440,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void validateProdDiff() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void validateProdDiff(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray arr = Nd4j.rand(3, 4); @@ -1388,7 +1466,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testSquare() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSquare(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int mb = 5; @@ -1410,7 +1490,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testExpandDims() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExpandDims(Nd4jBackend backend) { for (int i = 0; i <= 2; i++) { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", Nd4j.create(2, 3)); @@ -1434,7 +1516,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testZerosLike() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testZerosLike(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable var0 = sd.var("in", DataType.DOUBLE, new long[]{3, 4}); SDVariable out = sd.zerosLike("out", var0); @@ -1448,7 +1532,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testOnesLike() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOnesLike(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable var0 = sd.var("in", new long[]{3, 4}); SDVariable out = sd.onesLike("out", var0); @@ -1463,7 +1549,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testOnesLikeBackprop() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOnesLikeBackprop(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable var0 = sd.var("in", new long[]{3, 4}); SDVariable ones = sd.onesLike("ones", var0); @@ -1479,7 +1567,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testManhattanAlongDim0() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testManhattanAlongDim0(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray a = Nd4j.rand(new long[]{3, 4, 5}); @@ -1494,7 +1584,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testJaccardDistance() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJaccardDistance(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray a = Nd4j.rand(new long[]{3, 4}).addi(0.1); @@ -1520,7 +1612,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testPairwiseBooleanTransforms() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPairwiseBooleanTransforms(Nd4jBackend backend) { /* eq, neq, gt, lt, gte, lte, or, and, xor */ @@ -1606,7 +1700,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testBooleanChecks() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBooleanChecks(Nd4jBackend backend) { /* isNonDecreasing, */ @@ -1650,7 +1746,9 @@ public class SameDiffTests extends BaseNd4jTest { @Disabled(/*AS - 20191114 https://github.com/eclipse/deeplearning4j/issues/8393*/) @Test - public void testIsStrictlyIncShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsStrictlyIncShape(Nd4jBackend backend) { int nOut = 0; int minibatch = 0; @@ -1661,7 +1759,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testExpandDims2d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExpandDims2d(Nd4jBackend backend) { val origShape = new long[]{3, 4}; for (int i = 0; i < 3; i++) { @@ -1698,7 +1798,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testSqueezeDims() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSqueezeDims(Nd4jBackend backend) { val origShape = new long[]{3, 4, 5}; for (int i = 0; i < 3; i++) { @@ -1739,7 +1841,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testExpandSqueezeChain() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExpandSqueezeChain(Nd4jBackend backend) { val origShape = new long[]{3, 4}; @@ -1763,7 +1867,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testSqueezeExpandChain() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSqueezeExpandChain(Nd4jBackend backend) { val origShape = new long[]{3, 4, 5}; @@ -1791,7 +1897,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testConfusionMatrix() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConfusionMatrix(Nd4jBackend backend) { INDArray labels = Nd4j.createFromArray(1, 2, 4); INDArray pred = Nd4j.createFromArray(2, 2, 4); INDArray weights = Nd4j.createFromArray(10, 100, 1000); @@ -1810,7 +1918,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testArgMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArgMax(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); for (val dim : new int[][]{{0}, {1}, {Integer.MAX_VALUE}, {0, 1}, {}}) { @@ -1829,7 +1939,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testArgMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArgMin(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); @@ -1849,7 +1961,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testScatterAdd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterAdd(Nd4jBackend backend) { INDArray arr1 = Nd4j.zeros(3, 3); INDArray arr2 = Nd4j.createFromArray(0, 1); INDArray arr3 = Nd4j.ones(2, 3); @@ -1871,7 +1985,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testScatterMul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterMul(Nd4jBackend backend) { INDArray arr1 = Nd4j.ones(3, 3); INDArray arr2 = Nd4j.createFromArray(0, 1); INDArray arr3 = Nd4j.zeros(2, 3); @@ -1893,7 +2009,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testScatterSub() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterSub(Nd4jBackend backend) { INDArray arr1 = Nd4j.ones(3, 3); INDArray arr2 = Nd4j.createFromArray(0, 1); INDArray arr3 = Nd4j.ones(2, 3); @@ -1915,7 +2033,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testScatterDiv() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterDiv(Nd4jBackend backend) { INDArray arr1 = Nd4j.ones(3, 3); INDArray arr2 = Nd4j.createFromArray(0, 1); INDArray arr3 = Nd4j.ones(2, 3).assign(2); @@ -1936,7 +2056,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testScatterMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterMax(Nd4jBackend backend) { INDArray arr1 = Nd4j.ones(3, 3); INDArray arr2 = Nd4j.createFromArray(0, 1); INDArray arr3 = Nd4j.ones(2, 3).assign(2); @@ -1957,7 +2079,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testScatterMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterMin(Nd4jBackend backend) { INDArray arr1 = Nd4j.ones(3, 3); INDArray arr2 = Nd4j.createFromArray(1, 2); INDArray arr3 = Nd4j.ones(2, 3).assign(-2.0f); @@ -1978,7 +2102,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testReciprocal() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReciprocal(Nd4jBackend backend) { INDArray inArr = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray expected = Nd4j.onesLike(inArr).divi(inArr); SameDiff sd = SameDiff.create(); @@ -1989,7 +2115,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testGather2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGather2(Nd4jBackend backend) { INDArray in = Nd4j.rand(DataType.FLOAT, 10, 10); INDArray indices = Nd4j.createFromArray(0, 1, 5); @@ -2007,7 +2135,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testGatherOp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGatherOp(Nd4jBackend backend) { INDArray in = Nd4j.rand(DataType.DOUBLE, 10, 10); INDArray indices = Nd4j.createFromArray(0, 1, 5); @@ -2036,7 +2166,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testConditions() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConditions(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); @@ -2073,7 +2205,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testGet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGet(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); INDArray arr = Nd4j.linspace(1, 100, 100).reshape('c', 10L, 10L); @@ -2101,7 +2235,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testGetRank3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRank3(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); INDArray arr = Nd4j.linspace(1, 1000, 1000).reshape('c', 10, 10, 10); @@ -2139,7 +2275,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testTensorArray1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTensorArray1(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); TensorArray tensorArray = sd.tensorArray(DataType.FLOAT); INDArray arr1 = Nd4j.create(new double[]{1, 2, 3, 4}, new int[]{2, 2}); @@ -2154,7 +2292,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testTensorArray2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTensorArray2(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); TensorArray tensorArray = sd.tensorArray(DataType.FLOAT); INDArray arr1 = Nd4j.create(new double[]{1, 2, 3, 4}, new int[]{2, 2}); @@ -2169,7 +2309,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testTensorArray3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTensorArray3(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); TensorArray tensorArray = sd.tensorArray(DataType.FLOAT); INDArray arr1 = Nd4j.create(new double[]{1, 2, 3, 4}, new int[]{2, 2}); @@ -2186,7 +2328,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testFill() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFill(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); INDArray shape = Nd4j.createFromArray(2, 2); INDArray expOut = Nd4j.valueArrayOf(new int[]{2, 2}, 42.0); @@ -2206,7 +2350,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testPermute() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPermute(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); INDArray arr = Nd4j.create(new double[]{ ///////////// @@ -2243,7 +2389,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testExecutionDifferentShapesAccumAlongDim() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExecutionDifferentShapesAccumAlongDim(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", Nd4j.linspace(1, 12, 12).reshape(3, 4)); @@ -2263,7 +2411,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testExecutionDifferentShapesIndexAccumAlongDim() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExecutionDifferentShapesIndexAccumAlongDim(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", Nd4j.linspace(1, 12, 12).reshape(3, 4)); @@ -2283,7 +2433,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testExternalErrorsSimple() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExternalErrorsSimple(Nd4jBackend backend) { INDArray externalGrad = Nd4j.linspace(1, 12, 12).reshape(3, 4); SameDiff sd = SameDiff.create(); @@ -2316,7 +2468,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testUpdatingGradient() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUpdatingGradient(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -2346,7 +2500,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testUpdatingGradientSimple() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUpdatingGradientSimple(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", Nd4j.linspace(1, 12, 12).reshape(3, 4)); SDVariable out = in.mul(2.0); @@ -2374,7 +2530,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testShapeUpdating() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShapeUpdating(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", DataType.FLOAT, 3, 5); @@ -2414,7 +2572,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testMultiOutput1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultiOutput1(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", Nd4j.create(3, 4)); @@ -2433,7 +2593,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testMultiOutput2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultiOutput2(Nd4jBackend backend) { //Edge case: no functions SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", Nd4j.scalar(0.0)); @@ -2451,7 +2613,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void sameDiffPlaceholderGrad() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void sameDiffPlaceholderGrad(Nd4jBackend backend) { INDArray x = Nd4j.ones(2, 2); INDArray y = Nd4j.ones(2, 2); @@ -2472,7 +2636,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testConvertToConstant() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConvertToConstant(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -2514,7 +2680,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testPlaceholderToConstant() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPlaceholderToConstant(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -2556,7 +2724,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testConvertToVariable() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConvertToVariable(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); @@ -2596,7 +2766,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testDoubleUseOfArray() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDoubleUseOfArray(Nd4jBackend backend) { //If array is reused, gradient check will fail INDArray a = Nd4j.rand(DataType.DOUBLE, new int[]{3, 4}); SameDiff sd = SameDiff.create(); @@ -2615,7 +2787,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testMultiGradientRecurrent() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultiGradientRecurrent(Nd4jBackend backend) { final INDArray input = Nd4j.rand(DataType.DOUBLE, new int[]{3, 4, 2}); final INDArray[] output = new INDArray[(int) input.size(2)]; for (int i = 0; i < input.size(2); i++) { @@ -2659,7 +2833,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testMultiGradientManualRecurrent() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultiGradientManualRecurrent(Nd4jBackend backend) { final INDArray input = Nd4j.rand(DataType.DOUBLE, new int[]{3, 4, 2}); final INDArray[] output = new INDArray[(int) input.size(2)]; for (int i = 0; i < input.size(2); i++) { @@ -2701,7 +2877,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testMultiGradient() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultiGradient(Nd4jBackend backend) { final INDArray input = Nd4j.rand(DataType.DOUBLE, new int[]{3, 4, 2}); SameDiff sd = SameDiff.create(); final SDVariable sdInput = sd.var("input", input); @@ -2720,7 +2898,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testNonScalarOutput1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNonScalarOutput1(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable linspace = sd.linspace("at", DataType.DOUBLE, 1, 15, 15); SDVariable a = sd.reshape("a", linspace, 3, 5); @@ -2741,7 +2921,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testNonScalarOutput2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNonScalarOutput2(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable a = sd.reshape("a", sd.linspace("at", DataType.DOUBLE, 1, 15, 15), 3, 5); SDVariable b = sd.var("b", Nd4j.ones(DataType.DOUBLE, 3, 5)); @@ -2761,7 +2943,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testNonScalarOutput3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNonScalarOutput3(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable a = sd.reshape("a", sd.linspace("at", DataType.DOUBLE, 1, 15, 15), 3, 5); SDVariable b = sd.var("b", Nd4j.ones(DataType.DOUBLE, 3, 5));//.add(3); @@ -2781,7 +2965,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testNonScalarOutput4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNonScalarOutput4(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable a = sd.var("a", DataType.DOUBLE, 3, 4); SDVariable b = sd.placeHolder("b", DataType.DOUBLE, 4, 5); @@ -2803,7 +2989,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testNonScalarOutput5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNonScalarOutput5(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable linspace = sd.linspace(DataType.DOUBLE, 1, 75, 75); SDVariable a = sd.reshape("a", linspace, 15, 5); @@ -2824,7 +3012,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testSameDiffBackprop1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSameDiffBackprop1(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); final SDVariable a = sd.var("a", Nd4j.rand(4, 4)); final SDVariable b = sd.var("b", Nd4j.rand(4, 4)); @@ -2838,7 +3028,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testSameDiffNoGradForConstantAndPlaceholder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSameDiffNoGradForConstantAndPlaceholder(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); final SDVariable a = sd.var("a", Nd4j.rand(4, 4)); final SDVariable b = sd.constant("b", Nd4j.rand(4, 4)); @@ -2853,7 +3045,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testDuplicateNamePlaceholder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDuplicateNamePlaceholder(Nd4jBackend backend) { for (int i = 0; i < 2; i++) { SameDiff sd = SameDiff.create(); @@ -2865,7 +3059,7 @@ public class SameDiffTests extends BaseNd4jTest { } catch (Throwable t) { String m = t.getMessage(); assertNotNull(m); - assertTrue(m.contains("already exists"),m); + assertTrue(m.contains("already exists"),m); } try { @@ -2874,7 +3068,7 @@ public class SameDiffTests extends BaseNd4jTest { } catch (Throwable t) { String m = t.getMessage(); assertNotNull(m); - assertTrue(m.contains("already exists"),m); + assertTrue(m.contains("already exists"),m); } try { @@ -2892,7 +3086,7 @@ public class SameDiffTests extends BaseNd4jTest { } catch (Throwable t) { String m = t.getMessage(); assertNotNull(m); - assertTrue(m.contains("already exists"),m); + assertTrue(m.contains("already exists"),m); } try { @@ -2901,13 +3095,15 @@ public class SameDiffTests extends BaseNd4jTest { } catch (Throwable t) { String m = t.getMessage(); assertNotNull(m); - assertTrue(m.contains("already exists"),m); + assertTrue(m.contains("already exists"),m); } } } @Test - public void testSameDiffGetArrayScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSameDiffGetArrayScalar(Nd4jBackend backend) { final INDArray array = Nd4j.rand(1, 1); final SameDiff sd = SameDiff.create(); final SDVariable a = sd.var("a", array.shape()); @@ -2915,7 +3111,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testVariableRenaming() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariableRenaming(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable v1 = sd.var("x", Nd4j.rand(DataType.FLOAT, 3, 4)); @@ -2937,7 +3135,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testVariableRenaming2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariableRenaming2(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable v1 = sd.placeHolder("x", DataType.FLOAT, 3, 4); @@ -2959,7 +3159,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testPlaceholderShapeValidation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPlaceholderShapeValidation(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable scalar = sd.scalar("scalar", 0.0f); SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4); @@ -3024,7 +3226,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testInferenceWithoutLabel() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInferenceWithoutLabel(Nd4jBackend backend) { //We don't need a value for the label placeholder to calculate most values here SameDiff sd = SameDiff.create(); @@ -3061,7 +3265,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testInferenceWithoutUnnecessaryPlaceholders() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInferenceWithoutUnnecessaryPlaceholders(Nd4jBackend backend) { //We don't need an array for 2 of the placeholders to calculate the SameDiff sd = SameDiff.create(); @@ -3103,7 +3309,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testConvertDTypes1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConvertDTypes1(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable x = sd.var("x", Nd4j.rand(DataType.FLOAT, 3, 4)); @@ -3147,7 +3355,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testConvertDTypes2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConvertDTypes2(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable x = sd.placeHolder("x", DataType.FLOAT, 3, 4); @@ -3199,7 +3409,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testGradFnRequiredVars() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGradFnRequiredVars(Nd4jBackend backend) { //User can explicitly request that gradients for specific vars are available when differentiating (creating grad function), // even if they normally wouldn't be needed or calculated @@ -3239,6 +3451,8 @@ public class SameDiffTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIf() throws IOException { SameDiff sd = SameDiff.create(); SDVariable a = sd.placeHolder("a", DataType.DOUBLE); @@ -3266,6 +3480,8 @@ public class SameDiffTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testNestedIf() throws IOException { SameDiff sd = SameDiff.create(); SDVariable a = sd.var("a", Nd4j.createFromArray(2.0)); @@ -3289,6 +3505,8 @@ public class SameDiffTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWhile() throws IOException { SameDiff sd = SameDiff.create(); @@ -3337,6 +3555,8 @@ public class SameDiffTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testNestedWhileIf() throws IOException { SameDiff sd = SameDiff.create(); SDVariable countIn = sd.constant(5); @@ -3362,7 +3582,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testMod_1(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMod_1(Nd4jBackend backend) { val sd = SameDiff.create(); val initial = sd.constant("initial", Nd4j.createFromArray(5.f, 6.f, 7.f)); val four = sd.constant("four", 4.0f); @@ -3374,7 +3596,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void castShapeTest1(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void castShapeTest1(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable x = sd.constant(Nd4j.createFromArray(1, 2, 3, 4)); SDVariable casted = x.castTo(DataType.FLOAT); @@ -3384,7 +3608,7 @@ public class SameDiffTests extends BaseNd4jTest { @Test @Disabled // casted shape is null - public void castShapeTestEmpty(){ + public void castShapeTestEmpty(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable x = sd.constant(Nd4j.empty(DataType.INT)); SDVariable casted = x.castTo(DataType.FLOAT); @@ -3395,7 +3619,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testEmptyShapeVar(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyShapeVar(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); try { @@ -3416,7 +3642,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testPReLU(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPReLU(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable input = sd.constant(Nd4j.createFromArray( @@ -3431,8 +3659,8 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable out = sd.nn.prelu("out", input, alpha, 2); TestCase tc = new TestCase(sd).expected("out", Nd4j.createFromArray(new double[][][]{{ - {-0.1, 10, 10, -0.1}, - {10, 10, -1, -1} + {-0.1, 10, 10, -0.1}, + {10, 10, -1, -1} }}).castTo(DataType.DOUBLE)).gradientCheck(true); String err = OpValidation.validate(tc); @@ -3440,7 +3668,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testSameDiffSeedReproducibilityVarInit() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSameDiffSeedReproducibilityVarInit(Nd4jBackend backend) { SameDiff sd0 = SameDiff.create(); SameDiff sd1 = SameDiff.create(); @@ -3465,7 +3695,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testCalculateGradientsAndOutputs(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCalculateGradientsAndOutputs(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3)); @@ -3488,9 +3720,11 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(outExp, outs); assertEquals(gExp, g); } - + @Test - public void testConcatVariableGrad() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatVariableGrad(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable label = sd.var("label", DataType.FLOAT, 3, 4); SDVariable a = sd.var("a", DataType.FLOAT, 3, 2); @@ -3510,7 +3744,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testSliceVariableGrad() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSliceVariableGrad(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); SDVariable label = sd.var("label", DataType.FLOAT, 3, 4); SDVariable input = sd.var("input", DataType.FLOAT, 3, 4); @@ -3528,7 +3764,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testTrainingConfigJson(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTrainingConfigJson(Nd4jBackend backend) { for(IEvaluation e : new IEvaluation[]{new Evaluation(), new RegressionEvaluation(), new EvaluationBinary(), new ROC(), new ROCMultiClass(), new ROCBinary(), new EvaluationCalibration()}) { TrainingConfig config = new TrainingConfig.Builder() @@ -3544,7 +3782,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testRngSanityCheck(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRngSanityCheck(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); for(DataType dt : new DataType[]{DataType.FLOAT, DataType.DOUBLE,DataType.BFLOAT16}) { if (!dt.isNumerical()) @@ -3559,7 +3799,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testMissingPlaceholderError() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMissingPlaceholderError(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); @@ -3583,7 +3825,9 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testEquals1(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEquals1(Nd4jBackend backend) { SameDiff sd1 = SameDiff.create(); SameDiff sd2 = SameDiff.create(); @@ -3630,7 +3874,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testConv2DWeightsFormat() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConv2DWeightsFormat(Nd4jBackend backend) { int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; int oH=2,oW=2; SameDiff sd = SameDiff.create(); @@ -3665,7 +3911,9 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testConv2DDifferentWeightsFormat() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConv2DDifferentWeightsFormat(Nd4jBackend backend) { int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; int oH=2,oW=2; SameDiff sd = SameDiff.create(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java index 0673429e0..ef0918eb7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java @@ -30,11 +30,13 @@ import java.util.Map; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.listeners.impl.ScoreListener; import org.nd4j.autodiff.listeners.records.History; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.classification.Evaluation; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -55,14 +57,13 @@ import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.weightinit.impl.XavierInitScheme; @Slf4j -public class SameDiffTrainingTest extends BaseNd4jTest { +public class SameDiffTrainingTest extends BaseNd4jTestWithBackends { - public SameDiffTrainingTest(Nd4jBackend backend) { - super(backend); - } @Test - public void irisTrainingSanityCheck() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void irisTrainingSanityCheck(Nd4jBackend backend) { DataSetIterator iter = new IrisDataSetIterator(150, 150); NormalizerStandardize std = new NormalizerStandardize(); @@ -134,7 +135,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest { @Test - public void irisTrainingEvalTest() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void irisTrainingEvalTest(Nd4jBackend backend) { DataSetIterator iter = new IrisDataSetIterator(150, 150); NormalizerStandardize std = new NormalizerStandardize(); @@ -184,7 +187,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest { @Test - public void irisTrainingValidationTest() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void irisTrainingValidationTest(Nd4jBackend backend) { DataSetIterator iter = new IrisDataSetIterator(150, 150); NormalizerStandardize std = new NormalizerStandardize(); @@ -239,6 +244,8 @@ public class SameDiffTrainingTest extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTrainingMixedDtypes(){ for (String u : new String[]{"adam", "nesterov", "adamax", "amsgrad"}) { @@ -301,7 +308,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest { } @Test - public void simpleClassification() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void simpleClassification(Nd4jBackend backend) { double learning_rate = 0.001; int seed = 7; org.nd4j.linalg.api.rng.Random rng = Nd4j.getRandom(); @@ -348,6 +357,8 @@ public class SameDiffTrainingTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTrainingEvalVarNotReqForLoss(){ //If a variable is not required for the loss - normally it won't be calculated //But we want to make sure it IS calculated here - so we can perform evaluation on it diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java index 195d8eb8c..59173bd6b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/CheckpointListenerTest.java @@ -25,11 +25,13 @@ import org.junit.Assert; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.dataset.IrisDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -48,11 +50,8 @@ import java.util.concurrent.TimeUnit; import static junit.framework.TestCase.assertTrue; import static org.junit.jupiter.api.Assertions.assertEquals; -public class CheckpointListenerTest extends BaseNd4jTest { +public class CheckpointListenerTest extends BaseNd4jTestWithBackends { - public CheckpointListenerTest(Nd4jBackend backend){ - super(backend); - } @Override public char ordering(){ @@ -96,7 +95,9 @@ public class CheckpointListenerTest extends BaseNd4jTest { @Test - public void testCheckpointEveryEpoch(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception { File dir = testDir.toFile(); SameDiff sd = getModel(); @@ -130,7 +131,9 @@ public class CheckpointListenerTest extends BaseNd4jTest { } @Test - public void testCheckpointEvery5Iter(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception { File dir = testDir.toFile(); SameDiff sd = getModel(); @@ -169,7 +172,9 @@ public class CheckpointListenerTest extends BaseNd4jTest { @Test - public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception { File dir = testDir.toFile(); SameDiff sd = getModel(); @@ -199,7 +204,7 @@ public class CheckpointListenerTest extends BaseNd4jTest { for(File f : files){ String s = f.getAbsolutePath(); // System.out.println(s); - for( int i=0; i>( + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void customEvalTest(Nd4jBackend backend){ + CustomEvaluation accuracyEval = new CustomEvaluation<>( (labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)), CustomEvaluation.mergeConcatenate()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java index 80dc7920a..621cdfa97 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java @@ -21,6 +21,8 @@ package org.nd4j.evaluation; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationCalibration; @@ -29,25 +31,24 @@ import org.nd4j.evaluation.classification.ROCBinary; import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; -public class EmptyEvaluationTests extends BaseNd4jTest { +public class EmptyEvaluationTests extends BaseNd4jTestWithBackends { - public EmptyEvaluationTests(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { return 'c'; } - @Test - public void testEmptyEvaluation() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyEvaluation (Nd4jBackend backend) { Evaluation e = new Evaluation(); System.out.println(e.stats()); @@ -62,7 +63,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest { } @Test - public void testEmptyRegressionEvaluation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyRegressionEvaluation (Nd4jBackend backend) { RegressionEvaluation re = new RegressionEvaluation(); re.stats(); @@ -76,7 +79,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest { } @Test - public void testEmptyEvaluationBinary() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyEvaluationBinary(Nd4jBackend backend) { EvaluationBinary eb = new EvaluationBinary(); eb.stats(); @@ -91,7 +96,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest { } @Test - public void testEmptyROC() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyROC(Nd4jBackend backend) { ROC roc = new ROC(); roc.stats(); @@ -106,7 +113,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest { } @Test - public void testEmptyROCBinary() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyROCBinary(Nd4jBackend backend) { ROCBinary rb = new ROCBinary(); rb.stats(); @@ -121,7 +130,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest { } @Test - public void testEmptyROCMultiClass() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyROCMultiClass(Nd4jBackend backend) { ROCMultiClass r = new ROCMultiClass(); r.stats(); @@ -136,7 +147,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest { } @Test - public void testEmptyEvaluationCalibration() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyEvaluationCalibration(Nd4jBackend backend) { EvaluationCalibration ec = new EvaluationCalibration(); ec.stats(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java index c40c38678..3b94ee60a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalCustomThreshold.java @@ -21,9 +21,11 @@ package org.nd4j.evaluation; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.EvaluationBinary; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; @@ -36,11 +38,8 @@ import java.util.Random; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -public class EvalCustomThreshold extends BaseNd4jTest { +public class EvalCustomThreshold extends BaseNd4jTestWithBackends { - public EvalCustomThreshold(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -48,7 +47,9 @@ public class EvalCustomThreshold extends BaseNd4jTest { } @Test - public void testEvaluationCustomBinaryThreshold() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationCustomBinaryThreshold(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); //Sanity checks: 0.5 threshold for 1-output and 2-output binary cases @@ -114,7 +115,9 @@ public class EvalCustomThreshold extends BaseNd4jTest { } @Test - public void testEvaluationCostArray() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationCostArray(Nd4jBackend backend) { int nExamples = 20; @@ -162,7 +165,9 @@ public class EvalCustomThreshold extends BaseNd4jTest { } @Test - public void testEvaluationBinaryCustomThreshold() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationBinaryCustomThreshold(Nd4jBackend backend) { //Sanity check: same results for 0.5 threshold vs. default (no threshold) int nExamples = 20; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java index 0d8ab24ab..ecc0b10f4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java @@ -21,6 +21,8 @@ package org.nd4j.evaluation; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationCalibration; @@ -31,7 +33,7 @@ import org.nd4j.evaluation.curves.Histogram; import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.RocCurve; import org.nd4j.evaluation.regression.RegressionEvaluation; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; @@ -42,11 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -public class EvalJsonTest extends BaseNd4jTest { +public class EvalJsonTest extends BaseNd4jTestWithBackends { - public EvalJsonTest(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -54,7 +53,9 @@ public class EvalJsonTest extends BaseNd4jTest { } @Test - public void testSerdeEmpty() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSerdeEmpty(Nd4jBackend backend) { boolean print = false; IEvaluation[] arr = new IEvaluation[] {new Evaluation(), new EvaluationBinary(), new ROCBinary(10), @@ -73,8 +74,10 @@ public class EvalJsonTest extends BaseNd4jTest { } } - @Test - public void testSerde() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSerde(Nd4jBackend backend) { boolean print = false; Nd4j.getRandom().setSeed(12345); @@ -121,8 +124,10 @@ public class EvalJsonTest extends BaseNd4jTest { } } - @Test - public void testSerdeExactRoc() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSerdeExactRoc(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); boolean print = false; @@ -199,8 +204,10 @@ public class EvalJsonTest extends BaseNd4jTest { } } - @Test - public void testJsonYamlCurves() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJsonYamlCurves(Nd4jBackend backend) { ROC roc = new ROC(0); INDArray evalLabel = @@ -251,8 +258,10 @@ public class EvalJsonTest extends BaseNd4jTest { } - @Test - public void testJsonWithCustomThreshold() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJsonWithCustomThreshold(Nd4jBackend backend) { //Evaluation - binary threshold Evaluation e = new Evaluation(0.25); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java index 25f606061..d2ec5aff5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalTest.java @@ -21,8 +21,10 @@ package org.nd4j.evaluation; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.evaluation.classification.Evaluation; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -39,11 +41,8 @@ import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; -public class EvalTest extends BaseNd4jTest { +public class EvalTest extends BaseNd4jTestWithBackends { - public EvalTest(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -52,7 +51,9 @@ public class EvalTest extends BaseNd4jTest { @Test - public void testEval() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEval(Nd4jBackend backend) { int classNum = 5; Evaluation eval = new Evaluation (classNum); @@ -91,7 +92,9 @@ public class EvalTest extends BaseNd4jTest { } @Test - public void testEval2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEval2(Nd4jBackend backend) { DataType dtypeBefore = Nd4j.defaultFloatingPointType(); Evaluation first = null; @@ -150,7 +153,9 @@ public class EvalTest extends BaseNd4jTest { } @Test - public void testStringListLabels() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStringListLabels(Nd4jBackend backend) { INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2); @@ -167,7 +172,9 @@ public class EvalTest extends BaseNd4jTest { } @Test - public void testStringHashLabels() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStringHashLabels(Nd4jBackend backend) { INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2); INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2); @@ -184,7 +191,9 @@ public class EvalTest extends BaseNd4jTest { } @Test - public void testEvalMasking() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvalMasking(Nd4jBackend backend) { int miniBatch = 5; int nOut = 3; int tsLength = 6; @@ -251,7 +260,9 @@ public class EvalTest extends BaseNd4jTest { } @Test - public void testFalsePerfectRecall() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFalsePerfectRecall(Nd4jBackend backend) { int testSize = 100; int numClasses = 5; int winner = 1; @@ -284,7 +295,9 @@ public class EvalTest extends BaseNd4jTest { } @Test - public void testEvaluationMerging() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationMerging(Nd4jBackend backend) { int nRows = 20; int nCols = 3; @@ -358,7 +371,9 @@ public class EvalTest extends BaseNd4jTest { @Test - public void testSingleClassBinaryClassification() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSingleClassBinaryClassification(Nd4jBackend backend) { Evaluation eval = new Evaluation(1); @@ -387,7 +402,9 @@ public class EvalTest extends BaseNd4jTest { } @Test - public void testEvalInvalid() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvalInvalid(Nd4jBackend backend) { Evaluation e = new Evaluation(5); e.eval(0, 1); e.eval(1, 0); @@ -400,7 +417,9 @@ public class EvalTest extends BaseNd4jTest { } @Test - public void testEvalMethods() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvalMethods(Nd4jBackend backend) { //Check eval(int,int) vs. eval(INDArray,INDArray) Evaluation e1 = new Evaluation(4); @@ -443,7 +462,9 @@ public class EvalTest extends BaseNd4jTest { @Test - public void testTopNAccuracy() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTopNAccuracy(Nd4jBackend backend) { Evaluation e = new Evaluation(null, 3); @@ -504,7 +525,9 @@ public class EvalTest extends BaseNd4jTest { @Test - public void testTopNAccuracyMerging() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTopNAccuracyMerging(Nd4jBackend backend) { Evaluation e1 = new Evaluation(null, 3); Evaluation e2 = new Evaluation(null, 3); @@ -552,7 +575,9 @@ public class EvalTest extends BaseNd4jTest { } @Test - public void testBinaryCase() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBinaryCase(Nd4jBackend backend) { INDArray ones10 = Nd4j.ones(10, 1); INDArray ones4 = Nd4j.ones(4, 1); INDArray zeros4 = Nd4j.zeros(4, 1); @@ -581,7 +606,9 @@ public class EvalTest extends BaseNd4jTest { } @Test - public void testF1FBeta_MicroMacroAveraging() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testF1FBeta_MicroMacroAveraging(Nd4jBackend backend) { //Confusion matrix: rows = actual, columns = predicted //[3, 1, 0] //[2, 2, 1] @@ -722,7 +749,9 @@ public class EvalTest extends BaseNd4jTest { @Test - public void testConfusionMatrixStats() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConfusionMatrixStats(Nd4jBackend backend) { Evaluation e = new Evaluation(); @@ -743,6 +772,8 @@ public class EvalTest extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testEvalBinaryMetrics(){ Evaluation ePosClass1_nOut2 = new Evaluation(2, 1); @@ -864,6 +895,8 @@ public class EvalTest extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testConfusionMatrixString(){ Evaluation e = new Evaluation(Arrays.asList("a","b","c")); @@ -914,6 +947,8 @@ public class EvalTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testEvaluationNaNs(){ Evaluation e = new Evaluation(); @@ -929,6 +964,8 @@ public class EvalTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSegmentation(){ for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case Nd4j.getRandom().setSeed(12345); @@ -1023,6 +1060,8 @@ public class EvalTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLabelReset(){ Map m = new HashMap<>(); @@ -1056,6 +1095,8 @@ public class EvalTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testEvalStatsBinaryCase(){ //Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java index 4bb45f5bb..d82a4fa64 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java @@ -21,9 +21,11 @@ package org.nd4j.evaluation; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.EvaluationBinary; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; @@ -38,11 +40,8 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.nd4j.evaluation.classification.EvaluationBinary.Metric.*; -public class EvaluationBinaryTest extends BaseNd4jTest { +public class EvaluationBinaryTest extends BaseNd4jTestWithBackends { - public EvaluationBinaryTest(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -50,7 +49,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest { } @Test - public void testEvaluationBinary() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationBinary(Nd4jBackend backend) { //Compare EvaluationBinary to Evaluation class DataType dtypeBefore = Nd4j.defaultFloatingPointType(); EvaluationBinary first = null; @@ -136,7 +137,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest { } @Test - public void testEvaluationBinaryMerging() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationBinaryMerging(Nd4jBackend backend) { int nOut = 4; int[] shape1 = {30, nOut}; int[] shape2 = {50, nOut}; @@ -163,7 +166,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest { } @Test - public void testEvaluationBinaryPerOutputMasking() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationBinaryPerOutputMasking(Nd4jBackend backend) { //Provide a mask array: "ignore" the masked steps @@ -172,7 +177,7 @@ public class EvaluationBinaryTest extends BaseNd4jTest { INDArray labels = Nd4j.create(new double[][] {{1, 1, 1}, {0, 0, 0}, {1, 1, 1}, {0, 1, 1}, {1, 0, 1}}); INDArray predicted = Nd4j.create(new double[][] {{0.9, 0.9, 0.9}, {0.7, 0.7, 0.7}, {0.6, 0.6, 0.6}, - {0.4, 0.4, 0.4}, {0.1, 0.1, 0.1}}); + {0.4, 0.4, 0.4}, {0.1, 0.1, 0.1}}); //Correct? // Y Y m @@ -206,7 +211,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest { } @Test - public void testTimeSeriesEval() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTimeSeriesEval(Nd4jBackend backend) { int[] shape = {2, 4, 3}; Nd4j.getRandom().setSeed(12345); @@ -230,12 +237,14 @@ public class EvaluationBinaryTest extends BaseNd4jTest { } @Test - public void testEvaluationBinaryWithROC() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationBinaryWithROC(Nd4jBackend backend) { //Simple test for nested ROCBinary in EvaluationBinary Nd4j.getRandom().setSeed(12345); INDArray l1 = Nd4j.getExecutioner() - .exec(new BernoulliDistribution(Nd4j.createUninitialized(new int[] {50, 4}), 0.5)); + .exec(new BernoulliDistribution(Nd4j.createUninitialized(new int[] {50, 4}), 0.5)); INDArray p1 = Nd4j.rand(50, 4); EvaluationBinary eb = new EvaluationBinary(4, 30); @@ -247,7 +256,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest { @Test - public void testEvaluationBinary3d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationBinary3d(Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); @@ -281,7 +292,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest { } @Test - public void testEvaluationBinary4d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationBinary4d(Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); @@ -315,7 +328,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest { } @Test - public void testEvaluationBinary3dMasking() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationBinary3dMasking(Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); @@ -376,7 +391,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest { } @Test - public void testEvaluationBinary4dMasking() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationBinary4dMasking(Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java index 4bc90e067..2d11b8c22 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java @@ -21,8 +21,10 @@ package org.nd4j.evaluation; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.evaluation.classification.EvaluationCalibration; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; @@ -39,19 +41,18 @@ import java.util.Random; import static org.junit.jupiter.api.Assertions.*; -public class EvaluationCalibrationTest extends BaseNd4jTest { +public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends { - public EvaluationCalibrationTest(Nd4jBackend backend) { - super(backend); - } @Override - public char ordering() { + public char ordering () { return 'c'; } - @Test - public void testReliabilityDiagram() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReliabilityDiagram (Nd4jBackend backend) { DataType dtypeBefore = Nd4j.defaultFloatingPointType(); EvaluationCalibration first = null; @@ -142,8 +143,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest { } } - @Test - public void testLabelAndPredictionCounts() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLabelAndPredictionCounts (Nd4jBackend backend) { int minibatch = 50; int nClasses = 3; @@ -170,8 +173,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest { assertArrayEquals(expPredictionCount, ec.getPredictionCountsEachClass()); } - @Test - public void testResidualPlots() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testResidualPlots (Nd4jBackend backend) { int minibatch = 50; int nClasses = 3; @@ -271,7 +276,9 @@ public class EvaluationCalibrationTest extends BaseNd4jTest { } } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSegmentation(){ for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case Nd4j.getRandom().setSeed(12345); @@ -365,8 +372,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest { } } - @Test - public void testEvaluationCalibration3d() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationCalibration3d (Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); @@ -397,8 +406,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest { assertEquals(e2d.stats(), e3d.stats()); } - @Test - public void testEvaluationCalibration3dMasking() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvaluationCalibration3dMasking (Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/NewInstanceTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/NewInstanceTest.java index dd325996d..2e4fee8c9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/NewInstanceTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/NewInstanceTest.java @@ -23,6 +23,8 @@ package org.nd4j.evaluation; import static org.junit.jupiter.api.Assertions.assertEquals; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationCalibration; @@ -30,17 +32,14 @@ import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROCBinary; import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.regression.RegressionEvaluation; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -public class NewInstanceTest extends BaseNd4jTest { +public class NewInstanceTest extends BaseNd4jTestWithBackends { - public NewInstanceTest(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -48,7 +47,9 @@ public class NewInstanceTest extends BaseNd4jTest { } @Test - public void testNewInstances() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNewInstances(Nd4jBackend backend) { boolean print = true; Nd4j.getRandom().setSeed(12345); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCBinaryTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCBinaryTest.java index 4ccdcda32..a653070a4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCBinaryTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCBinaryTest.java @@ -21,10 +21,12 @@ package org.nd4j.evaluation; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROCBinary; import org.nd4j.evaluation.curves.PrecisionRecallCurve; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; @@ -39,19 +41,17 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; -public class ROCBinaryTest extends BaseNd4jTest { - - public ROCBinaryTest(Nd4jBackend backend) { - super(backend); - } - +public class ROCBinaryTest extends BaseNd4jTestWithBackends { + @Override public char ordering() { return 'c'; } - @Test - public void testROCBinary() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testROCBinary(Nd4jBackend backend) { //Compare ROCBinary to ROC class DataType dtypeBefore = Nd4j.defaultFloatingPointType(); @@ -145,8 +145,10 @@ public class ROCBinaryTest extends BaseNd4jTest { } } - @Test - public void testRocBinaryMerging() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRocBinaryMerging(Nd4jBackend backend) { for (int nSteps : new int[]{30, 0}) { //0 == exact int nOut = 4; int[] shape1 = {30, nOut}; @@ -175,8 +177,10 @@ public class ROCBinaryTest extends BaseNd4jTest { } - @Test - public void testROCBinaryPerOutputMasking() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testROCBinaryPerOutputMasking(Nd4jBackend backend) { for (int nSteps : new int[]{30, 0}) { //0 == exact @@ -215,8 +219,10 @@ public class ROCBinaryTest extends BaseNd4jTest { - @Test - public void testROCBinary3d() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testROCBinary3d(Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); @@ -249,8 +255,10 @@ public class ROCBinaryTest extends BaseNd4jTest { } } - @Test - public void testROCBinary4d() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testROCBinary4d(Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); @@ -283,8 +291,10 @@ public class ROCBinaryTest extends BaseNd4jTest { } } - @Test - public void testROCBinary3dMasking() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testROCBinary3dMasking(Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); @@ -344,8 +354,10 @@ public class ROCBinaryTest extends BaseNd4jTest { } } - @Test - public void testROCBinary4dMasking() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testROCBinary4dMasking(Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java index 2333f6f7e..d8a1fecf8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCTest.java @@ -21,12 +21,14 @@ package org.nd4j.evaluation; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROCBinary; import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.RocCurve; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -39,11 +41,8 @@ import java.util.*; import static org.junit.jupiter.api.Assertions.*; -public class ROCTest extends BaseNd4jTest { +public class ROCTest extends BaseNd4jTestWithBackends { - public ROCTest(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -83,8 +82,10 @@ public class ROCTest extends BaseNd4jTest { expFPR.put(10 / 10.0, 0.0 / totalNegatives); } - @Test - public void testRocBasic() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRocBasic(Nd4jBackend backend) { //2 outputs here - probability distribution over classes (softmax) INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601}, @@ -126,8 +127,10 @@ public class ROCTest extends BaseNd4jTest { assertEquals(1.0, auc, 1e-6); } - @Test - public void testRocBasicSingleClass() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRocBasicSingleClass(Nd4jBackend backend) { //1 output here - single probability value (sigmoid) //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) @@ -164,8 +167,10 @@ public class ROCTest extends BaseNd4jTest { } - @Test - public void testRoc() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRoc(Nd4jBackend backend) { //Previous tests allowed for a perfect classifier with right threshold... INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}}); @@ -249,8 +254,10 @@ public class ROCTest extends BaseNd4jTest { } - @Test - public void testRocTimeSeriesNoMasking() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRocTimeSeriesNoMasking(Nd4jBackend backend) { //Same as first test... //2 outputs here - probability distribution over classes (softmax) @@ -296,8 +303,10 @@ public class ROCTest extends BaseNd4jTest { } } - @Test - public void testRocTimeSeriesMasking() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRocTimeSeriesMasking(Nd4jBackend backend) { //2 outputs here - probability distribution over classes (softmax) INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601}, @@ -346,8 +355,10 @@ public class ROCTest extends BaseNd4jTest { - @Test - public void testCompareRocAndRocMultiClass() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCompareRocAndRocMultiClass(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); //For 2 class case: ROC and Multi-class ROC should be the same... @@ -376,8 +387,10 @@ public class ROCTest extends BaseNd4jTest { } } - @Test - public void testCompare2Vs3Classes() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCompare2Vs3Classes(Nd4jBackend backend) { //ROC multi-class: 2 vs. 3 classes should be the same, if we add two of the classes together... //Both methods implement one vs. all ROC/AUC in different ways @@ -425,8 +438,10 @@ public class ROCTest extends BaseNd4jTest { } } - @Test - public void testROCMerging() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testROCMerging(Nd4jBackend backend) { int nArrays = 10; int minibatch = 64; int nROCs = 3; @@ -470,8 +485,10 @@ public class ROCTest extends BaseNd4jTest { } } - @Test - public void testROCMerging2() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testROCMerging2(Nd4jBackend backend) { int nArrays = 10; int minibatch = 64; int exactAllocBlockSize = 10; @@ -515,8 +532,10 @@ public class ROCTest extends BaseNd4jTest { } - @Test - public void testROCMultiMerging() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testROCMultiMerging(Nd4jBackend backend) { int nArrays = 10; int minibatch = 64; @@ -563,8 +582,10 @@ public class ROCTest extends BaseNd4jTest { } } - @Test - public void testAUCPrecisionRecall() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAUCPrecisionRecall(Nd4jBackend backend) { //Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob //at threshold 0 to 0.24999: tp=2, fp=1, fn=0, tn=0 prec=2/(2+1)=0.666, recall=2/2=1.0 //at threshold 0.25 to 0.33: tp=2, fp=0, fn=0, tn=1 prec=2/2=1, recall=2/2=1 @@ -610,8 +631,10 @@ public class ROCTest extends BaseNd4jTest { } - @Test - public void testRocAucExact() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRocAucExact(Nd4jBackend backend) { //Check the implementation vs. Scikitlearn /* @@ -773,8 +796,10 @@ public class ROCTest extends BaseNd4jTest { } - @Test - public void rocExactEdgeCaseReallocation() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void rocExactEdgeCaseReallocation(Nd4jBackend backend) { //Set reallocation block size to say 20, but then evaluate a 100-length array @@ -785,8 +810,10 @@ public class ROCTest extends BaseNd4jTest { } - @Test - public void testPrecisionRecallCurveGetPointMethods() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) { double[] threshold = new double[101]; double[] precision = threshold; double[] recall = new double[101]; @@ -821,8 +848,10 @@ public class ROCTest extends BaseNd4jTest { } } - @Test - public void testPrecisionRecallCurveConfusion() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) { //Sanity check: values calculated from the confusion matrix should match the PR curve values for (boolean removeRedundantPts : new boolean[] {true, false}) { @@ -860,7 +889,9 @@ public class ROCTest extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRocMerge(){ Nd4j.getRandom().setSeed(12345); @@ -904,7 +935,9 @@ public class ROCTest extends BaseNd4jTest { assertEquals(auprc, auprcAct, 1e-6); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRocMultiMerge(){ Nd4j.getRandom().setSeed(12345); @@ -953,7 +986,9 @@ public class ROCTest extends BaseNd4jTest { } } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRocBinaryMerge(){ Nd4j.getRandom().setSeed(12345); @@ -998,7 +1033,9 @@ public class ROCTest extends BaseNd4jTest { } } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSegmentationBinary(){ for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case Nd4j.getRandom().setSeed(12345); @@ -1088,7 +1125,9 @@ public class ROCTest extends BaseNd4jTest { } } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSegmentation(){ for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case Nd4j.getRandom().setSeed(12345); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java index f601c53a4..ad373785a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java @@ -21,9 +21,11 @@ package org.nd4j.evaluation; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; @@ -40,11 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; -public class RegressionEvalTest extends BaseNd4jTest { +public class RegressionEvalTest extends BaseNd4jTestWithBackends { - public RegressionEvalTest(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -52,7 +51,7 @@ public class RegressionEvalTest extends BaseNd4jTest { } @Test() - public void testEvalParameters() { + public void testEvalParameters(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { int specCols = 5; INDArray labels = Nd4j.ones(3); @@ -65,7 +64,9 @@ public class RegressionEvalTest extends BaseNd4jTest { } @Test - public void testPerfectPredictions() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPerfectPredictions(Nd4jBackend backend) { int nCols = 5; int nTestArrays = 100; @@ -92,7 +93,9 @@ public class RegressionEvalTest extends BaseNd4jTest { } @Test - public void testKnownValues() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testKnownValues(Nd4jBackend backend) { DataType dtypeBefore = Nd4j.defaultFloatingPointType(); RegressionEvaluation first = null; @@ -148,7 +151,9 @@ public class RegressionEvalTest extends BaseNd4jTest { @Test - public void testRegressionEvaluationMerging() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRegressionEvaluationMerging(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int nRows = 20; @@ -189,7 +194,9 @@ public class RegressionEvalTest extends BaseNd4jTest { } @Test - public void testRegressionEvalPerOutputMasking() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRegressionEvalPerOutputMasking(Nd4jBackend backend) { INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}}); @@ -216,6 +223,8 @@ public class RegressionEvalTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRegressionEvalTimeSeriesSplit(){ INDArray out1 = Nd4j.rand(new int[]{3, 5, 20}); @@ -238,7 +247,9 @@ public class RegressionEvalTest extends BaseNd4jTest { } @Test - public void testRegressionEval3d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRegressionEval3d(Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); @@ -270,7 +281,9 @@ public class RegressionEvalTest extends BaseNd4jTest { } @Test - public void testRegressionEval4d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRegressionEval4d(Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); @@ -302,7 +315,9 @@ public class RegressionEvalTest extends BaseNd4jTest { } @Test - public void testRegressionEval3dMasking() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRegressionEval3dMasking(Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); @@ -361,7 +376,9 @@ public class RegressionEvalTest extends BaseNd4jTest { } @Test - public void testRegressionEval4dMasking() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRegressionEval4dMasking(Nd4jBackend backend) { INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/TestLegacyJsonLoading.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/TestLegacyJsonLoading.java index 1aaae65d5..e25e6554f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/TestLegacyJsonLoading.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/TestLegacyJsonLoading.java @@ -22,10 +22,12 @@ package org.nd4j.evaluation; import org.apache.commons.io.FileUtils; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.evaluation.regression.RegressionEvaluation; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.common.io.ClassPathResource; @@ -34,11 +36,8 @@ import java.nio.charset.StandardCharsets; import static org.junit.jupiter.api.Assertions.assertEquals; -public class TestLegacyJsonLoading extends BaseNd4jTest { +public class TestLegacyJsonLoading extends BaseNd4jTestWithBackends { - public TestLegacyJsonLoading(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -46,7 +45,9 @@ public class TestLegacyJsonLoading extends BaseNd4jTest { } @Test - public void testEvalLegacyFormat() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEvalLegacyFormat(Nd4jBackend backend) throws Exception { File f = new ClassPathResource("regression_testing/eval_100b/evaluation.json").getFile(); String s = FileUtils.readFileToString(f, StandardCharsets.UTF_8); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/AveragingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/AveragingTests.java index ebef6af1d..d38f9107c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/AveragingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/AveragingTests.java @@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.ndarray.INDArray; @@ -38,17 +39,14 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.*; @Slf4j -@RunWith(Parameterized.class) -public class AveragingTests extends BaseNd4jTest { + +public class AveragingTests extends BaseNd4jTestWithBackends { private final int THREADS = 16; private final int LENGTH = 51200 * 4; - DataType initialType; + DataType initialType = Nd4j.dataType(); + - public AveragingTests(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } @BeforeEach public void setUp() { @@ -63,7 +61,9 @@ public class AveragingTests extends BaseNd4jTest { @Test - public void testSingleDeviceAveraging1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSingleDeviceAveraging1(Nd4jBackend backend) { INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0); INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0); INDArray array3 = Nd4j.valueArrayOf(LENGTH, 3.0); @@ -110,7 +110,9 @@ public class AveragingTests extends BaseNd4jTest { } @Test - public void testSingleDeviceAveraging2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSingleDeviceAveraging2(Nd4jBackend backend) { INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH); List arrays = new ArrayList<>(); for (int i = 0; i < THREADS; i++) @@ -127,7 +129,9 @@ public class AveragingTests extends BaseNd4jTest { @Test - public void testAccumulation1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAccumulation1(Nd4jBackend backend) { INDArray array1 = Nd4j.create(100).assign(1.0); INDArray array2 = Nd4j.create(100).assign(2.0); INDArray array3 = Nd4j.create(100).assign(3.0); @@ -140,7 +144,9 @@ public class AveragingTests extends BaseNd4jTest { @Test - public void testAccumulation2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAccumulation2(Nd4jBackend backend) { INDArray array1 = Nd4j.create(100).assign(1.0); INDArray array2 = Nd4j.create(100).assign(2.0); INDArray array3 = Nd4j.create(100).assign(3.0); @@ -155,7 +161,9 @@ public class AveragingTests extends BaseNd4jTest { @Test - public void testAccumulation3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAccumulation3(Nd4jBackend backend) { // we want to ensure that cuda backend is able to launch this op on cpu Nd4j.getAffinityManager().allowCrossDeviceAccess(false); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/DataTypeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/DataTypeTest.java index 5f01c8526..78b8f00dc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/DataTypeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/DataTypeTest.java @@ -23,8 +23,9 @@ package org.nd4j.linalg; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -34,15 +35,14 @@ import java.io.*; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) + @Slf4j -public class DataTypeTest extends BaseNd4jTest { - public DataTypeTest(Nd4jBackend backend) { - super(backend); - } +public class DataTypeTest extends BaseNd4jTestWithBackends { @Test - public void testDataTypes() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDataTypes(Nd4jBackend backend) throws Exception { for (val type : DataType.values()) { if (DataType.UTF8.equals(type) || DataType.UNKNOWN.equals(type) || DataType.COMPRESSED.equals(type)) continue; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/InputValidationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/InputValidationTests.java index f2c8b5419..f1a296783 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/InputValidationTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/InputValidationTests.java @@ -21,20 +21,17 @@ package org.nd4j.linalg; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.fail; -@RunWith(Parameterized.class) -public class InputValidationTests extends BaseNd4jTest { - public InputValidationTests(Nd4jBackend backend) { - super(backend); - } +public class InputValidationTests extends BaseNd4jTestWithBackends { @Override public char ordering() { @@ -45,7 +42,9 @@ public class InputValidationTests extends BaseNd4jTest { ///////////////////// Broadcast Tests /////////////////////// @Test - public void testInvalidColVectorOp1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInvalidColVectorOp1(Nd4jBackend backend) { INDArray first = Nd4j.create(10, 10); INDArray col = Nd4j.create(5, 1); try { @@ -57,7 +56,9 @@ public class InputValidationTests extends BaseNd4jTest { } @Test - public void testInvalidColVectorOp2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInvalidColVectorOp2(Nd4jBackend backend) { INDArray first = Nd4j.create(10, 10); INDArray col = Nd4j.create(5, 1); try { @@ -69,7 +70,9 @@ public class InputValidationTests extends BaseNd4jTest { } @Test - public void testInvalidRowVectorOp1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInvalidRowVectorOp1(Nd4jBackend backend) { INDArray first = Nd4j.create(10, 10); INDArray row = Nd4j.create(1, 5); try { @@ -81,7 +84,9 @@ public class InputValidationTests extends BaseNd4jTest { } @Test - public void testInvalidRowVectorOp2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInvalidRowVectorOp2(Nd4jBackend backend) { INDArray first = Nd4j.create(10, 10); INDArray row = Nd4j.create(1, 5); try { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java index e9175a0cf..d4fa89cf8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java @@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.lang3.RandomUtils; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; @@ -47,14 +48,13 @@ import static org.junit.jupiter.api.Assertions.*; @Slf4j -@RunWith(Parameterized.class) -public class LoneTest extends BaseNd4jTest { - public LoneTest(Nd4jBackend backend) { - super(backend); - } + +public class LoneTest extends BaseNd4jTestWithBackends { @Test - public void testSoftmaxStability() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftmaxStability(Nd4jBackend backend) { INDArray input = Nd4j.create(new double[]{-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose(); // System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo())); INDArray output = Nd4j.create(DataType.DOUBLE, 10, 1); @@ -68,7 +68,9 @@ public class LoneTest extends BaseNd4jTest { } @Test - public void testFlattenedView() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFlattenedView(Nd4jBackend backend) { int rows = 8; int cols = 8; int dim2 = 4; @@ -104,7 +106,9 @@ public class LoneTest extends BaseNd4jTest { } @Test - public void testIndexingColVec() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIndexingColVec(Nd4jBackend backend) { int elements = 5; INDArray rowVector = Nd4j.linspace(1, elements, elements).reshape(1, elements); INDArray colVector = rowVector.transpose(); @@ -123,7 +127,9 @@ public class LoneTest extends BaseNd4jTest { } @Test - public void concatScalarVectorIssue() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void concatScalarVectorIssue(Nd4jBackend backend) { //A bug was found when the first array that concat sees is a scalar and the rest vectors + scalars INDArray arr1 = Nd4j.create(1, 1); INDArray arr2 = Nd4j.create(1, 8); @@ -133,7 +139,9 @@ public class LoneTest extends BaseNd4jTest { } @Test - public void reshapeTensorMmul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void reshapeTensorMmul(Nd4jBackend backend) { INDArray a = Nd4j.linspace(1, 2, 12).reshape(2, 3, 2); INDArray b = Nd4j.linspace(3, 4, 4).reshape(2, 2); int[][] axes = new int[2][]; @@ -145,7 +153,9 @@ public class LoneTest extends BaseNd4jTest { } @Test - public void maskWhenMerge() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void maskWhenMerge(Nd4jBackend backend) { DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5)); DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3)); List dataSetList = new ArrayList(); @@ -160,7 +170,9 @@ public class LoneTest extends BaseNd4jTest { } @Test - public void testRelu() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRelu(Nd4jBackend backend) { INDArray aA = Nd4j.linspace(-3, 4, 8).reshape(2, 4); INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4); INDArray b = Nd4j.getExecutioner().exec(new Tanh(aA)); @@ -172,7 +184,7 @@ public class LoneTest extends BaseNd4jTest { @Test //broken at a threshold - public void testArgMax() { + public void testArgMax(Nd4jBackend backend) { int max = 63; INDArray A = Nd4j.linspace(1, max, max).reshape(1, max); int currentArgMax = Nd4j.argMax(A).getInt(0); @@ -186,7 +198,9 @@ public class LoneTest extends BaseNd4jTest { } @Test - public void testRPF() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRPF(Nd4jBackend backend) { val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3); log.info("--------"); @@ -199,7 +213,9 @@ public class LoneTest extends BaseNd4jTest { } @Test - public void testConcat3D_Vstack_C() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcat3D_Vstack_C(Nd4jBackend backend) { val shape = new long[]{1, 1000, 20}; List cArrays = new ArrayList<>(); @@ -229,7 +245,9 @@ public class LoneTest extends BaseNd4jTest { @Test - public void testGetRow1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRow1(Nd4jBackend backend) { INDArray array = Nd4j.create(10000, 10000); //Thread.sleep(10000); @@ -256,7 +274,7 @@ public class LoneTest extends BaseNd4jTest { } @Test() - public void checkIllegalElementOps() { + public void checkIllegalElementOps(Nd4jBackend backend) { assertThrows(Exception.class,() -> { INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5); INDArray B = A.dup().reshape(2, 2, 5); @@ -268,7 +286,9 @@ public class LoneTest extends BaseNd4jTest { } @Test - public void checkSliceofSlice() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void checkSliceofSlice(Nd4jBackend backend) { /* Issue 1: Slice of slice with c order and f order views are not equal @@ -308,7 +328,9 @@ public class LoneTest extends BaseNd4jTest { } @Test - public void checkWithReshape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void checkWithReshape(Nd4jBackend backend) { INDArray arr = Nd4j.create(1, 3); INDArray reshaped = arr.reshape('f', 3, 1); for (int i=0;i> list = new ArrayList<>(100); for (int i = 0; i < 100; i++) { - Future future = ex.submit(new Runnable() { - @Override - public void run() { - INDArray dot = Nd4j.linspace(1, 8, 8, DataType.DOUBLE); + Future future = ex.submit(() -> { + INDArray dot = Nd4j.linspace(1, 8, 8, DataType.DOUBLE); // System.out.println(Transforms.sigmoid(dot)); - Transforms.sigmoid(dot); - } + Transforms.sigmoid(dot); }); list.add(future); } @@ -191,7 +196,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testBroadcastingGenerated() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcastingGenerated(Nd4jBackend backend) { int[][] broadcastShape = NDArrayCreationUtil.getRandomBroadCastShape(7, 6, 10); List>> broadCastList = new ArrayList<>(broadcastShape.length); for (int[] shape : broadcastShape) { @@ -206,7 +213,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest { INDArray inputArrBroadcast = val.getFirst(); val destShape = NDArrayCreationUtil.broadcastToShape(inputArrBroadcast.shape(), 7); INDArray output = inputArrBroadcast - .broadcast(NDArrayCreationUtil.broadcastToShape(inputArrBroadcast.shape(), 7)); + .broadcast(NDArrayCreationUtil.broadcastToShape(inputArrBroadcast.shape(), 7)); assertArrayEquals(destShape, output.shape()); } } @@ -216,7 +223,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testBroadCasting() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadCasting(Nd4jBackend backend) { INDArray first = Nd4j.arange(0, 3).reshape(3, 1).castTo(DataType.DOUBLE); INDArray ret = first.broadcast(3, 4); INDArray testRet = Nd4j.create(new double[][] {{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}}); @@ -229,14 +238,18 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testOneTensor() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOneTensor(Nd4jBackend backend) { INDArray arr = Nd4j.ones(1, 1, 1, 1, 1, 1, 1); INDArray matrixToBroadcast = Nd4j.ones(1, 1); assertEquals(matrixToBroadcast.broadcast(arr.shape()), arr); } @Test - public void testSortWithIndicesDescending() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSortWithIndicesDescending(Nd4jBackend backend) { INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); //indices,data INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, false); @@ -247,7 +260,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testSortDeadlock() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSortDeadlock(Nd4jBackend backend) { val toSort = Nd4j.linspace(DataType.DOUBLE, 1, 32*768, 1).reshape(32, 768); val sorted = Nd4j.sort(toSort.dup(), 1, false); @@ -255,7 +270,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testSortWithIndices() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSortWithIndices(Nd4jBackend backend) { INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); //indices,data INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, true); @@ -266,14 +283,18 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testNd4jSortScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNd4jSortScalar(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(1, -1); INDArray sorted = Nd4j.sort(linspace, 1, false); // System.out.println(sorted); } @Test - public void testSwapAxesFortranOrder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSwapAxesFortranOrder(Nd4jBackend backend) { INDArray n = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}).castTo(DataType.DOUBLE); for (int i = 0; i < n.slices(); i++) { INDArray nSlice = n.slice(i); @@ -292,7 +313,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testDimShuffle() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDimShuffle(Nd4jBackend backend) { INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray twoOneTwo = n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false}); assertTrue(Arrays.equals(new long[] {2, 1, 2}, twoOneTwo.shape())); @@ -303,7 +326,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testGetVsGetScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetVsGetScalar(Nd4jBackend backend) { INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); float element = a.getFloat(0, 1); double element2 = a.getDouble(0, 1); @@ -316,7 +341,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testDivide() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDivide(Nd4jBackend backend) { INDArray two = Nd4j.create(new float[] {2, 2, 2, 2}); INDArray div = two.div(two); assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage()); @@ -330,7 +357,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testSigmoid() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSigmoid(Nd4jBackend backend) { INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f}); INDArray sigmoid = Transforms.sigmoid(n, false); @@ -339,7 +368,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testNeg() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNeg(Nd4jBackend backend) { INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4}); INDArray neg = Transforms.neg(n); @@ -349,7 +380,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testCosineSim() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCosineSim(Nd4jBackend backend) { INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4}); double sim = Transforms.cosineSim(vec1, vec2); @@ -364,7 +397,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testExp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExp(Nd4jBackend backend) { INDArray n = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f}); INDArray exped = Transforms.exp(n); @@ -374,7 +409,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalar(Nd4jBackend backend) { INDArray a = Nd4j.scalar(1.0f); assertEquals(true, a.isScalar()); @@ -386,7 +423,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testWrap() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testWrap(Nd4jBackend backend) { int[] shape = {2, 4}; INDArray d = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(shape[0], shape[1]); INDArray n = d; @@ -411,7 +450,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testGetRowFortran() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRowFortran(Nd4jBackend backend) { INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.FLOAT).data(), new long[] {2, 2}); INDArray column = Nd4j.create(new float[] {1, 3}); INDArray column2 = Nd4j.create(new float[] {2, 4}); @@ -424,7 +465,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testGetColumnFortran() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetColumnFortran(Nd4jBackend backend) { INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2}); INDArray column = Nd4j.create(new double[] {1, 2}); INDArray column2 = Nd4j.create(new double[] {3, 4}); @@ -438,7 +481,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testGetColumns() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetColumns(Nd4jBackend backend) { INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3).castTo(DataType.DOUBLE); // log.info("Original: {}", matrix); INDArray matrixGet = matrix.getColumns(1, 2); @@ -452,7 +497,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testVectorInit() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorInit(Nd4jBackend backend) { DataBuffer data = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(); INDArray arr = Nd4j.create(data, new long[] {1, 4}); assertEquals(true, arr.isRowVector()); @@ -465,7 +512,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testAssignOffset() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAssignOffset(Nd4jBackend backend) { INDArray arr = Nd4j.ones(5, 5); INDArray row = arr.slice(1); row.assign(1); @@ -473,7 +522,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testColumns() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testColumns(Nd4jBackend backend) { INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE); INDArray column = Nd4j.create(new double[] {1, 2, 3}); arr.putColumn(0, column); @@ -511,7 +562,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testPutRow() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutRow(Nd4jBackend backend) { INDArray d = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray n = d.dup(); @@ -570,7 +623,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testInplaceTranspose() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInplaceTranspose(Nd4jBackend backend) { INDArray test = Nd4j.rand(3, 4); INDArray orig = test.dup(); INDArray transposei = test.transposei(); @@ -585,7 +640,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testMmulF() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmulF(Nd4jBackend backend) { DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data(); INDArray n = Nd4j.create(data, new long[] {1, 10}); @@ -603,7 +660,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testRowsColumns() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRowsColumns(Nd4jBackend backend) { DataBuffer data = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).data(); INDArray rows = Nd4j.create(data, new long[] {2, 3}); assertEquals(2, rows.rows()); @@ -619,7 +678,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testTranspose() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTranspose(Nd4jBackend backend) { INDArray n = Nd4j.create(Nd4j.ones(100).castTo(DataType.DOUBLE).data(), new long[] {5, 5, 4}); INDArray transpose = n.transpose(); assertEquals(n.length(), transpose.length()); @@ -647,7 +708,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testAddMatrix() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAddMatrix(Nd4jBackend backend) { INDArray five = Nd4j.ones(5); five.addi(five.dup()); INDArray twos = Nd4j.valueArrayOf(5, 2); @@ -658,7 +721,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testMMul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMMul(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}}); INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}}); @@ -669,7 +734,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testPutSlice() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutSlice(Nd4jBackend backend) { INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3); INDArray newSlice = Nd4j.create(DataType.DOUBLE, 3, 3); Nd4j.exec(new PrintVariable(newSlice)); @@ -680,7 +747,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testRowVectorMultipleIndices() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRowVectorMultipleIndices(Nd4jBackend backend) { INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4); linear.putScalar(new long[] {0, 1}, 1); assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage()); @@ -689,7 +758,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testDim1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDim1(Nd4jBackend backend) { INDArray sum = Nd4j.linspace(1, 2, 2, DataType.DOUBLE).reshape(2, 1); INDArray same = sum.dup(); assertEquals(same.sum(1), sum.reshape(2)); @@ -697,7 +768,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testEps() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEps(Nd4jBackend backend) { val ones = Nd4j.ones(5); val res = Nd4j.createUninitialized(DataType.BOOL, 5); assertTrue(Nd4j.getExecutioner().exec(new Eps(ones, ones, res)).all()); @@ -705,7 +778,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testLogDouble() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLogDouble(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).castTo(DataType.DOUBLE); INDArray log = Transforms.log(linspace); INDArray assertion = Nd4j.create(new double[] {0, 0.6931471805599453, 1.0986122886681098, 1.3862943611198906, 1.6094379124341005, 1.791759469228055}); @@ -713,28 +788,36 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testVectorSum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorSum(Nd4jBackend backend) { INDArray lin = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1); } @Test - public void testVectorSum2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorSum2(Nd4jBackend backend) { INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4}); assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1); } @Test - public void testVectorSum3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorSum3(Nd4jBackend backend) { INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray lin2 = Nd4j.create(new double[] {1, 2, 3, 4}); assertEquals(lin, lin2); } @Test - public void testSmallSum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSmallSum(Nd4jBackend backend) { INDArray base = Nd4j.create(new double[] {5.843333333333335, 3.0540000000000007}); base.addi(1e-12); INDArray assertion = Nd4j.create(new double[] {5.84333433, 3.054001}); @@ -745,7 +828,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testPermute() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPermute(Nd4jBackend backend) { INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4}); INDArray transpose = n.transpose(); INDArray permute = n.permute(1, 0); @@ -774,7 +859,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testAppendBias() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAppendBias(Nd4jBackend backend) { INDArray rand = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(1, -1).transpose(); INDArray test = Nd4j.appendBias(rand); INDArray assertion = Nd4j.toFlattened(rand, Nd4j.scalar(DataType.DOUBLE, 1.0)).reshape(-1, 1); @@ -782,7 +869,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testRand() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRand(Nd4jBackend backend) { INDArray rand = Nd4j.randn(5, 5); Nd4j.getDistributions().createUniform(0.4, 4).sample(5); Nd4j.getDistributions().createNormal(1, 5).sample(10); @@ -794,7 +883,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testIdentity() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIdentity(Nd4jBackend backend) { INDArray eye = Nd4j.eye(5); assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape())); eye = Nd4j.eye(5); @@ -805,7 +896,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testColumnVectorOpsFortran() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testColumnVectorOpsFortran(Nd4jBackend backend) { INDArray twoByTwo = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray toAdd = Nd4j.create(new float[] {1, 2}, new long[] {2, 1}); twoByTwo.addiColumnVector(toAdd); @@ -816,7 +909,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testRSubi() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRSubi(Nd4jBackend backend) { INDArray n2 = Nd4j.ones(2); INDArray n2Assertion = Nd4j.zeros(2); INDArray nRsubi = n2.rsubi(1); @@ -826,7 +921,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testAssign() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAssign(Nd4jBackend backend) { INDArray vector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); vector.assign(1); assertEquals(Nd4j.ones(5).castTo(DataType.DOUBLE), vector); @@ -843,7 +940,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testAddScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAddScalar(Nd4jBackend backend) { INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0); INDArray rdiv = div.add(1); INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 5.0); @@ -851,7 +950,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testRdivScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRdivScalar(Nd4jBackend backend) { INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0); INDArray rdiv = div.rdiv(1); INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 0.25); @@ -859,7 +960,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testRDivi() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRDivi(Nd4jBackend backend) { INDArray n2 = Nd4j.valueArrayOf(new long[] {1, 2}, 4.0); INDArray n2Assertion = Nd4j.valueArrayOf(new long[] {1, 2}, 0.5); INDArray nRsubi = n2.rdivi(2); @@ -869,7 +972,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testNumVectorsAlongDimension() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNumVectorsAlongDimension(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2); assertEquals(12, arr.vectorsAlongDimension(2)); } @@ -877,7 +982,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testBroadCast() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadCast(Nd4jBackend backend) { INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray broadCasted = n.broadcast(5, 4); for (int i = 0; i < broadCasted.rows(); i++) { @@ -899,7 +1006,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testMatrix() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatrix(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray brr = Nd4j.create(new double[] {5, 6}, new long[] {2}); INDArray row = arr.getRow(0); @@ -909,7 +1018,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testPutRowGetRowOrdering() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutRowGetRowOrdering(Nd4jBackend backend) { INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray put = Nd4j.create(new double[] {5, 6}); row1.putRow(1, put); @@ -931,7 +1042,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testSumWithRow1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSumWithRow1(Nd4jBackend backend) { //Works: INDArray array2d = Nd4j.ones(1, 10); array2d.sum(0); //OK @@ -962,7 +1075,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testSumWithRow2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSumWithRow2(Nd4jBackend backend) { //All sums in this method execute without exceptions. INDArray array3d = Nd4j.ones(2, 10, 10); array3d.sum(0); @@ -985,7 +1100,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testPutRowFortran() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutRowFortran(Nd4jBackend backend) { INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2).castTo(DataType.DOUBLE); INDArray put = Nd4j.create(new double[] {5, 6}); row1.putRow(1, put); @@ -998,7 +1115,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testElementWiseOps() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testElementWiseOps(Nd4jBackend backend) { INDArray n1 = Nd4j.scalar(1); INDArray n2 = Nd4j.scalar(2); INDArray nClone = n1.add(n2); @@ -1021,7 +1140,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test - public void testRollAxis() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRollAxis(Nd4jBackend backend) { INDArray toRoll = Nd4j.ones(3, 4, 5, 6); assertArrayEquals(new long[] {3, 6, 4, 5}, Nd4j.rollAxis(toRoll, 3, 1).shape()); val shape = Nd4j.rollAxis(toRoll, 3).shape(); @@ -1030,20 +1151,22 @@ public class NDArrayTestsFortran extends BaseNd4jTest { @Test @Disabled - public void testTensorDot() { + public void testTensorDot(Nd4jBackend backend) { INDArray oneThroughSixty = Nd4j.arange(60).reshape('f', 3, 4, 5).castTo(DataType.DOUBLE); INDArray oneThroughTwentyFour = Nd4j.arange(24).reshape('f', 4, 3, 2).castTo(DataType.DOUBLE); INDArray result = Nd4j.tensorMmul(oneThroughSixty, oneThroughTwentyFour, new int[][] {{1, 0}, {0, 1}}); assertArrayEquals(new long[] {5, 2}, result.shape()); INDArray assertion = Nd4j.create(new double[][] {{440., 1232.}, {1232., 3752.}, {2024., 6272.}, {2816., 8792.}, - {3608., 11312.}}); + {3608., 11312.}}); assertEquals(assertion, result); } @Test - public void testNegativeShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNegativeShape(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray reshaped = linspace.reshape(-1, 2); assertArrayEquals(new long[] {2, 2}, reshaped.shape()); @@ -1055,7 +1178,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testGetColumnGetRow() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetColumnGetRow(Nd4jBackend backend) { INDArray row = Nd4j.ones(1, 5); for (int i = 0; i < 5; i++) { INDArray col = row.getColumn(i); @@ -1070,7 +1195,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testDupAndDupWithOrder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDupAndDupWithOrder(Nd4jBackend backend) { List> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE); int count = 0; for (Pair pair : testInputs) { @@ -1092,7 +1219,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest { } @Test - public void testToOffsetZeroCopy() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToOffsetZeroCopy(Nd4jBackend backend) { List> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE); int cnt = 0; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 660ef4a8e..9c7482b0b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -33,14 +33,14 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.primitives.Pair; import org.nd4j.common.util.ArrayUtil; import org.nd4j.common.util.MathUtils; import org.nd4j.enums.WeightsFormat; -import org.nd4j.imports.tfgraphs.NodeReader; import org.nd4j.linalg.api.blas.Level1; import org.nd4j.linalg.api.blas.params.GemmParams; import org.nd4j.linalg.api.blas.params.MMulTranspose; @@ -151,18 +151,12 @@ import static org.junit.jupiter.api.Assertions.*; * @author Adam Gibson */ @Slf4j -@RunWith(Parameterized.class) -public class Nd4jTestsC extends BaseNd4jTest { - DataType initialType; - Level1 l1; +public class Nd4jTestsC extends BaseNd4jTestWithBackends { + DataType initialType = Nd4j.dataType(); + Level1 l1 = Nd4j.getBlasWrapper().level1(); - public Nd4jTestsC(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - l1 = Nd4j.getBlasWrapper().level1(); - } @Override public long getTimeoutMilliseconds() { @@ -183,14 +177,18 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testArangeNegative() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArangeNegative(Nd4jBackend backend) { INDArray arr = Nd4j.arange(-2,2).castTo(DataType.DOUBLE); INDArray assertion = Nd4j.create(new double[]{-2, -1, 0, 1}); assertEquals(assertion,arr); } @Test - public void testTri() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTri(Nd4jBackend backend) { INDArray assertion = Nd4j.create(new double[][]{ {1,1,1,0,0}, {1,1,1,1,0}, @@ -203,7 +201,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testTriu() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTriu(Nd4jBackend backend) { INDArray input = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(4,3); int k = -1; INDArray test = Nd4j.triu(input,k); @@ -218,13 +218,17 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testDiag() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDiag(Nd4jBackend backend) { INDArray diag = Nd4j.diag(Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(4,1)); assertArrayEquals(new long[] {4,4},diag.shape()); } @Test - public void testGetRowEdgeCase() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRowEdgeCase(Nd4jBackend backend) { INDArray orig = Nd4j.linspace(1,300,300, DataType.DOUBLE).reshape('c', 100, 3); INDArray col = orig.getColumn(0).reshape(100, 1); @@ -244,7 +248,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNd4jEnvironment() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNd4jEnvironment(Nd4jBackend backend) { System.out.println(Nd4j.getExecutioner().getEnvironmentInformation()); int manualNumCores = Integer.parseInt(Nd4j.getExecutioner().getEnvironmentInformation() .get(Nd4jEnvironment.CPU_CORES_KEY).toString()); @@ -254,6 +260,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSerialization(@TempDir Path testDir) throws Exception { Nd4j.getRandom().setSeed(12345); INDArray arr = Nd4j.rand(1, 20); @@ -278,7 +286,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTensorAlongDimension2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTensorAlongDimension2(Nd4jBackend backend) { INDArray array = Nd4j.create(new float[100], new long[] {50, 1, 2}); assertArrayEquals(new long[] {1, 2}, array.slice(0, 0).shape()); @@ -286,7 +296,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Disabled // with broadcastables mechanic it'll be ok @Test - public void testShapeEqualsOnElementWise() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShapeEqualsOnElementWise(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { Nd4j.ones(10000, 1).sub(Nd4j.ones(1, 2)); @@ -294,7 +306,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testIsMaxVectorCase() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsMaxVectorCase(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[] {1, 2, 4, 3}, new long[] {2, 2}); INDArray assertion = Nd4j.create(new boolean[] {false, false, true, false}, new long[] {2, 2}, DataType.BOOL); INDArray test = Nd4j.getExecutioner().exec(new IsMax(arr))[0]; @@ -302,7 +316,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testArgMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArgMax(Nd4jBackend backend) { INDArray toArgMax = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2); INDArray argMaxZero = Nd4j.argMax(toArgMax, 0); INDArray argMax = Nd4j.argMax(toArgMax, 1); @@ -317,7 +333,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testArgMax_119() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArgMax_119(Nd4jBackend backend) { val array = Nd4j.create(new double[]{1, 2, 119, 2}); val max = array.argMax(); @@ -326,7 +344,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAutoBroadcastShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAutoBroadcastShape(Nd4jBackend backend) { val assertion = new long[]{2,2,2,5}; val shapeTest = Shape.broadcastOutputShape(new long[]{2,1,2,1},new long[]{2,1,5}); assertArrayEquals(assertion,shapeTest); @@ -334,7 +354,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test @Disabled //temporary till libnd4j implements general broadcasting - public void testAutoBroadcastAdd() { + public void testAutoBroadcastAdd(Nd4jBackend backend) { INDArray left = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,1,2,1); INDArray right = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,1,5); INDArray assertion = Nd4j.create(new double[]{2,3,4,5,6,3,4,5,6,7,7,8,9,10,11,8,9,10,11,12,4,5,6,7,8,5,6,7,8,9,9,10,11,12,13,10,11,12,13,14}).reshape(2,2,2,5); @@ -343,7 +363,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAudoBroadcastAddMatrix() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAudoBroadcastAddMatrix(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,2); INDArray row = Nd4j.ones(1, 2); INDArray assertion = arr.add(1.0); @@ -352,7 +374,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testScalarOps() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarOps(Nd4jBackend backend) { INDArray n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3}); assertEquals(27d, n.length(), 1e-1); n.addi(Nd4j.scalar(1d)); @@ -368,7 +392,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTensorAlongDimension() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTensorAlongDimension(Nd4jBackend backend) { val shape = new long[] {4, 5, 7}; int length = ArrayUtil.prod(shape); INDArray arr = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape(shape); @@ -392,7 +418,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testMmulWithTranspose() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmulWithTranspose(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,2); INDArray arr2 = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,2).transpose(); INDArray arrTransposeAssertion = arr.transpose().mmul(arr2); @@ -415,7 +443,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testGetDouble() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetDouble(Nd4jBackend backend) { INDArray n2 = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}); INDArray swapped = n2.swapAxes(n2.shape().length - 1, 1); INDArray slice0 = swapped.slice(0).slice(1); @@ -424,6 +454,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWriteTxt() throws Exception { INDArray row = Nd4j.create(new double[][] {{1, 2}, {3, 4}}); ByteArrayOutputStream bos = new ByteArrayOutputStream(); @@ -435,7 +467,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void test2dMatrixOrderingSwitch() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test2dMatrixOrderingSwitch(Nd4jBackend backend) { char order = Nd4j.order(); INDArray c = Nd4j.create(new double[][] {{1, 2}, {3, 4}}, 'c'); assertEquals('c', c.ordering()); @@ -446,7 +480,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMatrix() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatrix(Nd4jBackend backend) { INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray brr = Nd4j.create(new float[] {5, 6}, new long[] {2}); INDArray row = arr.getRow(0); @@ -456,7 +492,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMMul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMMul(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}}); INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}}); @@ -483,7 +521,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSubiRowVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSubiRowVector(Nd4jBackend backend) { INDArray oneThroughFour = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape('c', 2, 2); INDArray row1 = oneThroughFour.getRow(1).dup(); oneThroughFour.subiRowVector(row1); @@ -494,7 +534,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testAddiRowVectorWithScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAddiRowVectorWithScalar(Nd4jBackend backend) { INDArray colVector = Nd4j.create(5, 1).assign(0.0); INDArray scalar = Nd4j.create(1, 1).assign(0.0); scalar.putScalar(0, 1); @@ -507,7 +549,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTADOnVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTADOnVector(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray rowVec = Nd4j.rand(1, 10); @@ -532,7 +576,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testLength() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLength(Nd4jBackend backend) { INDArray values = Nd4j.create(2, 2); INDArray values2 = Nd4j.create(2, 2); @@ -556,7 +602,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBroadCasting() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadCasting(Nd4jBackend backend) { INDArray first = Nd4j.arange(0, 3).reshape(3, 1).castTo(DataType.DOUBLE); INDArray ret = first.broadcast(3, 4); INDArray testRet = Nd4j.create(new double[][] {{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}}); @@ -569,7 +617,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testGetColumns() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetColumns(Nd4jBackend backend) { INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray matrixGet = matrix.getColumns(1, 2); INDArray matrixAssertion = Nd4j.create(new double[][] {{2, 3}, {5, 6}}); @@ -577,7 +627,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSort() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSort(Nd4jBackend backend) { INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray ascending = Nd4j.sort(toSort.dup(), 1, true); //rows are already sorted @@ -589,7 +641,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSortRows() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSortRows(Nd4jBackend backend) { int nRows = 10; int nCols = 5; java.util.Random r = new java.util.Random(12345); @@ -623,7 +677,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testToFlattenedOrder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToFlattenedOrder(Nd4jBackend backend) { INDArray concatC = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape('c', 2, 2); INDArray concatF = Nd4j.create(new long[] {2, 2}, 'f'); concatF.assign(concatC); @@ -638,7 +694,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testZero() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testZero(Nd4jBackend backend) { Nd4j.ones(11).sumNumber(); Nd4j.ones(12).sumNumber(); Nd4j.ones(2).sumNumber(); @@ -646,7 +704,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSumNumberRepeatability() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSumNumberRepeatability(Nd4jBackend backend) { INDArray arr = Nd4j.ones(1, 450).reshape('c', 150, 3); double first = arr.sumNumber().doubleValue(); @@ -660,7 +720,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testToFlattened2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToFlattened2(Nd4jBackend backend) { int rows = 3; int cols = 4; int dim2 = 5; @@ -701,7 +763,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testToFlattenedOnViews() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToFlattenedOnViews(Nd4jBackend backend) { int rows = 8; int cols = 8; int dim2 = 4; @@ -749,7 +813,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testIsMax2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsMax2(Nd4jBackend backend) { //Tests: full buffer... //1d INDArray arr1 = Nd4j.create(new double[] {1, 2, 3, 1}); @@ -777,7 +843,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testToFlattened3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToFlattened3(Nd4jBackend backend) { INDArray inC1 = Nd4j.create(new long[] {10, 100}, 'c'); INDArray inC2 = Nd4j.create(new long[] {1, 100}, 'c'); @@ -799,7 +867,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testIsMaxEqualValues() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsMaxEqualValues(Nd4jBackend backend) { //Assumption here: should only have a 1 for *first* maximum value, if multiple values are exactly equal //[1 1 1] -> [1 0 0] @@ -814,28 +884,36 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testIMaxVector_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIMaxVector_1(Nd4jBackend backend) { val array = Nd4j.ones(3); val idx = array.argMax(0).getInt(0); assertEquals(0, idx); } @Test - public void testIMaxVector_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIMaxVector_2(Nd4jBackend backend) { val array = Nd4j.ones(3); val idx = array.argMax(Integer.MAX_VALUE).getInt(0); assertEquals(0, idx); } @Test - public void testIMaxVector_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIMaxVector_3(Nd4jBackend backend) { val array = Nd4j.ones(3); val idx = array.argMax().getInt(0); assertEquals(0, idx); } @Test - public void testIsMaxEqualValues_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsMaxEqualValues_2(Nd4jBackend backend) { //[0 2] [0 1] //[2 1] -> [0 0]bg INDArray orig = Nd4j.create(new double[][] {{0, 3}, {2, 1}}); @@ -851,7 +929,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testIsMaxEqualValues_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsMaxEqualValues_3(Nd4jBackend backend) { //[0 2] [0 1] //[2 1] -> [0 0] INDArray orig = Nd4j.create(new double[][] {{0, 2}, {3, 1}}); @@ -864,7 +944,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSqrt_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSqrt_1(Nd4jBackend backend) { val x = Nd4j.createFromArray(9.0, 9.0, 9.0, 9.0); val x2 = Nd4j.createFromArray(9.0, 9.0, 9.0, 9.0); val e = Nd4j.createFromArray(3.0, 3.0, 3.0, 3.0); @@ -880,7 +962,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAssign_CF() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAssign_CF(Nd4jBackend backend) { val orig = Nd4j.create(new double[][] {{0, 2}, {2, 1}}); val oc = orig.dup('c'); val of = orig.dup('f'); @@ -890,7 +974,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testIsMaxAlongDimension() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsMaxAlongDimension(Nd4jBackend backend) { //1d: row vector INDArray orig = Nd4j.create(new double[] {1, 2, 3, 1}).reshape(1,4 ); @@ -959,7 +1045,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testIMaxSingleDim1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIMaxSingleDim1(Nd4jBackend backend) { INDArray orig2d = Nd4j.create(new double[][] {{1, 0, 2}, {2, 3, 1}}); INDArray result = Nd4j.argMax(orig2d.dup('c'), 0); @@ -968,7 +1056,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testIsMaxSingleDim1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsMaxSingleDim1(Nd4jBackend backend) { INDArray orig2d = Nd4j.create(new double[][] {{1, 0, 2}, {2, 3, 1}}); INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0))[0]; INDArray expAlong0_2d = Nd4j.create(new boolean[][] {{false, false, true}, {true, true, false}}); @@ -981,7 +1071,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBroadcastRepeated() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcastRepeated(Nd4jBackend backend) { INDArray z = Nd4j.create(1, 4, 4, 3); INDArray bias = Nd4j.create(1, 3); BroadcastOp op = new BroadcastAddOp(z, bias, z, 3); @@ -999,7 +1091,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testVStackDifferentOrders() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVStackDifferentOrders(Nd4jBackend backend) { INDArray expected = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3); for (char order : new char[] {'c', 'f'}) { @@ -1022,7 +1116,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testVStackEdgeCase() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVStackEdgeCase(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray vstacked = Nd4j.vstack(arr); assertEquals(arr.reshape(1,4), vstacked); @@ -1030,7 +1126,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testEps3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEps3(Nd4jBackend backend) { INDArray first = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); INDArray second = Nd4j.linspace(20, 30, 10, DataType.DOUBLE); @@ -1049,7 +1147,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test @Disabled - public void testSumAlongDim1sEdgeCases() { + public void testSumAlongDim1sEdgeCases(Nd4jBackend backend) { val shapes = new long[][] { //Standard case: {2, 2, 3, 4}, @@ -1105,7 +1203,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testIsMaxAlongDimensionSimple() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsMaxAlongDimensionSimple(Nd4jBackend backend) { //Simple test: when doing IsMax along a dimension, we expect all values to be either 0 or 1 //Do IsMax along dims 0&1 for rank 2, along 0,1&2 for rank 3, etc @@ -1141,7 +1241,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSortColumns() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSortColumns(Nd4jBackend backend) { int nRows = 5; int nCols = 10; java.util.Random r = new java.util.Random(12345); @@ -1173,7 +1275,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testAddVectorWithOffset() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAddVectorWithOffset(Nd4jBackend backend) { INDArray oneThroughFour = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray row1 = oneThroughFour.getRow(1); row1.addi(1); @@ -1186,7 +1290,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testLinearViewGetAndPut() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLinearViewGetAndPut(Nd4jBackend backend) { INDArray test = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray linear = test.reshape(-1); linear.putScalar(2, 6); @@ -1198,7 +1304,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testRowVectorGemm() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRowVectorGemm(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, 4); INDArray other = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4); INDArray result = linspace.mmul(other); @@ -1207,6 +1315,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGemmStrided(){ for( val x : new int[]{5, 1}) { @@ -1239,7 +1349,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testMultiSum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultiSum(Nd4jBackend backend) { /** * ([[[ 0., 1.], [ 2., 3.]], @@ -1290,7 +1402,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSum2dv2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum2dv2(Nd4jBackend backend) { INDArray in = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape('c', 2, 2, 2); val dims = new int[][] {{0, 1}, {1, 0}, {0, 2}, {2, 0}, {1, 2}, {2, 1}}; @@ -1311,7 +1425,9 @@ public class Nd4jTestsC extends BaseNd4jTest { //Passes on 3.9: @Test - public void testSum3Of4_2222() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum3Of4_2222(Nd4jBackend backend) { int[] shape = {2, 2, 2, 2}; int length = ArrayUtil.prod(shape); INDArray arrC = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape('c', shape); @@ -1335,7 +1451,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBroadcast1d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcast1d(Nd4jBackend backend) { int[] shape = {4, 3, 2}; int[] toBroadcastDims = new int[] {0, 1, 2}; int[][] toBroadcastShapes = new int[][] {{1, 4}, {1, 3}, {1, 2}}; @@ -1392,7 +1510,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSum3Of4_3322() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum3Of4_3322(Nd4jBackend backend) { int[] shape = {3, 3, 2, 2}; int length = ArrayUtil.prod(shape); INDArray arrC = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape('c', shape); @@ -1416,7 +1536,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testToFlattened() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToFlattened(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); List concat = new ArrayList<>(); for (int i = 0; i < 3; i++) { @@ -1431,7 +1553,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testDup() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDup(Nd4jBackend backend) { for (int x = 0; x < 100; x++) { INDArray orig = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray dup = orig.dup(); @@ -1453,7 +1577,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSortWithIndicesDescending() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSortWithIndicesDescending(Nd4jBackend backend) { INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); //indices,data INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, false); @@ -1466,14 +1592,18 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testGetFromRowVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetFromRowVector(Nd4jBackend backend) { INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray rowGet = matrix.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 2)); assertArrayEquals(new long[] {2}, rowGet.shape()); } @Test - public void testSubRowVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSubRowVector(Nd4jBackend backend) { INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray row = Nd4j.linspace(1, 3, 3, DataType.DOUBLE); INDArray test = matrix.subRowVector(row); @@ -1492,7 +1622,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testDimShuffle() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDimShuffle(Nd4jBackend backend) { INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray twoOneTwo = n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false}); assertTrue(Arrays.equals(new long[] {2, 1, 2}, twoOneTwo.shape())); @@ -1503,7 +1635,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testGetVsGetScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetVsGetScalar(Nd4jBackend backend) { INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); float element = a.getFloat(0, 1); double element2 = a.getDouble(0, 1); @@ -1516,7 +1650,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testDivide() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDivide(Nd4jBackend backend) { INDArray two = Nd4j.create(new double[] {2, 2, 2, 2}); INDArray div = two.div(two); assertEquals(Nd4j.ones(4), div); @@ -1530,7 +1666,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSigmoid() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSigmoid(Nd4jBackend backend) { INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f}); INDArray sigmoid = Transforms.sigmoid(n, false); @@ -1538,7 +1676,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNeg() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNeg(Nd4jBackend backend) { INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4}); INDArray neg = Transforms.neg(n); @@ -1547,7 +1687,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNorm2Double() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm2Double(Nd4jBackend backend) { DataType initialType = Nd4j.dataType(); Nd4j.setDataType(DataType.DOUBLE); @@ -1567,7 +1709,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testNorm2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm2(Nd4jBackend backend) { INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); float assertion = 5.47722557505f; float norm3 = n.norm2Number().floatValue(); @@ -1585,7 +1729,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testCosineSim() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCosineSim(Nd4jBackend backend) { INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4}); double sim = Transforms.cosineSim(vec1, vec2); @@ -1600,7 +1746,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testScal() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScal(Nd4jBackend backend) { double assertion = 2; INDArray answer = Nd4j.create(new double[] {2, 4, 6, 8}); INDArray scal = Nd4j.getBlasWrapper().scal(assertion, answer); @@ -1616,7 +1764,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testExp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExp(Nd4jBackend backend) { INDArray n = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f}); INDArray exped = Transforms.exp(n); @@ -1628,7 +1778,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSlices() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSlices(Nd4jBackend backend) { INDArray arr = Nd4j.create(Nd4j.linspace(1, 24, 24, DataType.DOUBLE).data(), new long[] {4, 3, 2}); for (int i = 0; i < arr.slices(); i++) { assertEquals(2, arr.slice(i).slice(1).slices()); @@ -1638,7 +1790,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalar(Nd4jBackend backend) { INDArray a = Nd4j.scalar(1.0f); assertEquals(true, a.isScalar()); @@ -1648,7 +1802,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testWrap() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testWrap(Nd4jBackend backend) { int[] shape = {2, 4}; INDArray d = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(shape[0], shape[1]); INDArray n = d; @@ -1675,7 +1831,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testVectorInit() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorInit(Nd4jBackend backend) { DataBuffer data = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(); INDArray arr = Nd4j.create(data, new long[] {1, 4}); assertEquals(true, arr.isRowVector()); @@ -1688,7 +1846,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testColumns() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testColumns(Nd4jBackend backend) { INDArray arr = Nd4j.create(new long[] {3, 2}); INDArray column2 = arr.getColumn(0); //assertEquals(true, Shape.shapeEquals(new long[]{3, 1}, column2.shape())); @@ -1729,7 +1889,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testPutRow() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutRow(Nd4jBackend backend) { INDArray d = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray slice1 = d.slice(1); INDArray n = d.dup(); @@ -1796,7 +1958,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testMulRowVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMulRowVector(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); arr.muliRowVector(Nd4j.linspace(1, 2, 2, DataType.DOUBLE)); INDArray assertion = Nd4j.create(new double[][] {{1, 4}, {3, 8}}); @@ -1807,7 +1971,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum(Nd4jBackend backend) { INDArray n = Nd4j.create(Nd4j.linspace(1, 8, 8, DataType.DOUBLE).data(), new long[] {2, 2, 2}); INDArray test = Nd4j.create(new double[] {3, 7, 11, 15}, new long[] {2, 2}); INDArray sum = n.sum(-1); @@ -1818,7 +1984,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testInplaceTranspose() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInplaceTranspose(Nd4jBackend backend) { INDArray test = Nd4j.rand(3, 4); INDArray orig = test.dup(); INDArray transposei = test.transposei(); @@ -1831,7 +1999,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTADMMul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTADMMul(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); val shape = new long[] {4, 5, 7}; INDArray arr = Nd4j.rand(shape); @@ -1859,7 +2029,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTADMMulLeadingOne() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTADMMulLeadingOne(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); val shape = new long[] {1, 5, 7}; INDArray arr = Nd4j.rand(shape); @@ -1889,7 +2061,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSum2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum2(Nd4jBackend backend) { INDArray test = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray sum = test.sum(1); INDArray assertion = Nd4j.create(new float[] {3, 7}); @@ -1900,7 +2074,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testGetIntervalEdgeCase() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetIntervalEdgeCase(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int[] shape = {3, 2, 4}; @@ -1944,7 +2120,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testGetIntervalEdgeCase2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetIntervalEdgeCase2(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int[] shape = {3, 2, 4}; @@ -1968,7 +2146,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testMmul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmul(Nd4jBackend backend) { DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data(); INDArray n = Nd4j.create(data, new long[] {1, 10}); INDArray transposed = n.transpose(); @@ -2035,7 +2215,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testRowsColumns() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRowsColumns(Nd4jBackend backend) { DataBuffer data = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).data(); INDArray rows = Nd4j.create(data, new long[] {2, 3}); assertEquals(2, rows.rows()); @@ -2051,7 +2233,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testTranspose() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTranspose(Nd4jBackend backend) { INDArray n = Nd4j.create(Nd4j.ones(100).data(), new long[] {5, 5, 4}).castTo(DataType.DOUBLE); INDArray transpose = n.transpose(); assertEquals(n.length(), transpose.length()); @@ -2074,7 +2258,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testLogX1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLogX1(Nd4jBackend backend) { INDArray x = Nd4j.create(10).assign(7); INDArray logX5 = Transforms.log(x, 5, true); @@ -2085,7 +2271,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAddMatrix() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAddMatrix(Nd4jBackend backend) { INDArray five = Nd4j.ones(5); five.addi(five); INDArray twos = Nd4j.valueArrayOf(5, 2); @@ -2095,7 +2283,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testPutSlice() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutSlice(Nd4jBackend backend) { INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3); INDArray newSlice = Nd4j.zeros(3, 3); n.putSlice(0, newSlice); @@ -2105,14 +2295,16 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testRowVectorMultipleIndices() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRowVectorMultipleIndices(Nd4jBackend backend) { INDArray linear = Nd4j.create(1, 4); linear.putScalar(new long[] {0, 1}, 1); assertEquals(linear.getDouble(0, 1), 1, 1e-1); } @Test() - public void testSize() { + public void testSize(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { INDArray arr = Nd4j.create(4, 5); @@ -2126,7 +2318,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNullPointerDataBuffer() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNullPointerDataBuffer(Nd4jBackend backend) { DataType initialType = Nd4j.dataType(); Nd4j.setDataType(DataType.FLOAT); @@ -2142,7 +2336,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testEps() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEps(Nd4jBackend backend) { INDArray ones = Nd4j.ones(5); val res = Nd4j.create(DataType.BOOL, 5); Nd4j.getExecutioner().exec(new Eps(ones, ones, res)); @@ -2152,7 +2348,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testEps2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEps2(Nd4jBackend backend) { INDArray first = Nd4j.valueArrayOf(10, 1e-2); //0.01 INDArray second = Nd4j.zeros(10); //0.0 @@ -2168,7 +2366,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testLogDouble() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLogDouble(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); INDArray log = Transforms.log(linspace); INDArray assertion = Nd4j.create(new double[] {0, 0.6931471805599453, 1.0986122886681098, 1.3862943611198906, @@ -2177,14 +2377,18 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testDupDimension() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDupDimension(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); assertEquals(arr.tensorAlongDimension(0, 1), arr.tensorAlongDimension(0, 1)); } @Test - public void testIterator() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIterator(Nd4jBackend backend) { INDArray x = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray repeated = x.repeat(1, 2); assertEquals(8, repeated.length()); @@ -2195,7 +2399,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTile() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTile(Nd4jBackend backend) { INDArray x = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray repeated = x.repeat(0, 2); assertEquals(8, repeated.length()); @@ -2211,7 +2417,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNegativeOneReshape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNegativeOneReshape(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[] {0, 1, 2}); INDArray newShape = arr.reshape(-1); assertEquals(newShape, arr); @@ -2219,7 +2427,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSmallSum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSmallSum(Nd4jBackend backend) { INDArray base = Nd4j.create(new double[] {5.843333333333335, 3.0540000000000007}); base.addi(1e-12); INDArray assertion = Nd4j.create(new double[] {5.84333433, 3.054001}); @@ -2229,7 +2439,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void test2DArraySlice() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test2DArraySlice(Nd4jBackend backend) { INDArray array2D = Nd4j.ones(5, 7); /** * This should be reverse. @@ -2256,7 +2468,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test @Disabled - public void testTensorDot() { + public void testTensorDot(Nd4jBackend backend) { INDArray oneThroughSixty = Nd4j.arange(60).reshape(3, 4, 5).castTo(DataType.DOUBLE); INDArray oneThroughTwentyFour = Nd4j.arange(24).reshape(4, 3, 2).castTo(DataType.DOUBLE); INDArray result = Nd4j.tensorMmul(oneThroughSixty, oneThroughTwentyFour, new int[][] {{1, 0}, {0, 1}}); @@ -2281,7 +2493,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testGetRow() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRow(Nd4jBackend backend) { INDArray arr = Nd4j.ones(10, 4); for (int i = 0; i < 10; i++) { INDArray row = arr.getRow(i); @@ -2291,7 +2505,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testGetPermuteReshapeSub() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetPermuteReshapeSub(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray first = Nd4j.rand(new long[] {10, 4}); @@ -2312,7 +2528,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testPutAtIntervalIndexWithStride() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutAtIntervalIndexWithStride(Nd4jBackend backend) { INDArray n1 = Nd4j.create(3, 3).assign(0.0); INDArrayIndex[] indices = {NDArrayIndex.interval(0, 2, 3), NDArrayIndex.all()}; n1.put(indices, 1); @@ -2321,7 +2539,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMMulMatrixTimesColVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMMulMatrixTimesColVector(Nd4jBackend backend) { //[1 1 1 1 1; 10 10 10 10 10; 100 100 100 100 100] x [1; 1; 1; 1; 1] = [5; 50; 500] INDArray matrix = Nd4j.ones(3, 5); matrix.getRow(1).muli(10); @@ -2336,7 +2556,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testMMulMixedOrder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMMulMixedOrder(Nd4jBackend backend) { INDArray first = Nd4j.ones(5, 2); INDArray second = Nd4j.ones(2, 3); INDArray out = first.mmul(second); @@ -2360,7 +2582,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testFTimesCAddiRow() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFTimesCAddiRow(Nd4jBackend backend) { INDArray arrF = Nd4j.create(2, 3, 'f').assign(1.0); INDArray arrC = Nd4j.create(2, 3, 'c').assign(1.0); @@ -2387,7 +2611,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testMmulGet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmulGet(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345L); INDArray elevenByTwo = Nd4j.rand(new long[] {11, 2}); INDArray twoByEight = Nd4j.rand(new long[] {2, 8}); @@ -2404,7 +2630,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testMMulRowColVectorMixedOrder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMMulRowColVectorMixedOrder(Nd4jBackend backend) { INDArray colVec = Nd4j.ones(5, 1); INDArray rowVec = Nd4j.ones(1, 3); INDArray out = colVec.mmul(rowVec); @@ -2427,7 +2655,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMMulFTimesC() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMMulFTimesC(Nd4jBackend backend) { int nRows = 3; int nCols = 3; java.util.Random r = new java.util.Random(12345); @@ -2452,7 +2682,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMMulColVectorRowVectorMixedOrder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMMulColVectorRowVectorMixedOrder(Nd4jBackend backend) { INDArray colVec = Nd4j.ones(5, 1); INDArray rowVec = Nd4j.ones(1, 5); INDArray out = rowVec.mmul(colVec); @@ -2474,7 +2706,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testPermute() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPermute(Nd4jBackend backend) { INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4}); INDArray transpose = n.transpose(); INDArray permute = n.permute(1, 0); @@ -2489,7 +2723,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testPermutei() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPermutei(Nd4jBackend backend) { //Check in-place permute vs. copy array permute //2d: @@ -2570,7 +2806,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testPermuteiShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPermuteiShape(Nd4jBackend backend) { INDArray row = Nd4j.create(1, 10); @@ -2604,7 +2842,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSwapAxes() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSwapAxes(Nd4jBackend backend) { INDArray n = Nd4j.create(Nd4j.linspace(0, 7, 8, DataType.DOUBLE).data(), new long[] {2, 2, 2}); INDArray assertion = n.permute(2, 1, 0); INDArray permuteTranspose = assertion.slice(1).slice(1); @@ -2622,7 +2862,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testMuliRowVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMuliRowVector(Nd4jBackend backend) { INDArray arrC = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape('c', 3, 2); INDArray arrF = Nd4j.create(new long[] {3, 2}, 'f').assign(arrC); @@ -2647,7 +2889,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSliceConstructor() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSliceConstructor(Nd4jBackend backend) { List testList = new ArrayList<>(); for (int i = 0; i < 5; i++) testList.add(Nd4j.scalar(i + 1.0f)); @@ -2660,7 +2904,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testStdev0() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStdev0(Nd4jBackend backend) { double[][] ind = {{5.1, 3.5, 1.4}, {4.9, 3.0, 1.4}, {4.7, 3.2, 1.3}}; INDArray in = Nd4j.create(ind); INDArray stdev = in.std(0); @@ -2670,7 +2916,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testStdev1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStdev1(Nd4jBackend backend) { double[][] ind = {{5.1, 3.5, 1.4}, {4.9, 3.0, 1.4}, {4.7, 3.2, 1.3}}; INDArray in = Nd4j.create(ind); INDArray stdev = in.std(1); @@ -2681,7 +2929,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSignXZ() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSignXZ(Nd4jBackend backend) { double[] d = {1.0, -1.1, 1.2, 1.3, -1.4, -1.5, 1.6, -1.7, -1.8, -1.9, -1.01, -1.011}; double[] e = {1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0}; @@ -2715,7 +2965,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTanhXZ() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTanhXZ(Nd4jBackend backend) { INDArray arrC = Nd4j.linspace(-6, 6, 12, DataType.DOUBLE).reshape('c', 4, 3); INDArray arrF = Nd4j.create(new long[] {4, 3}, 'f').assign(arrC); double[] d = arrC.data().asDouble(); @@ -2750,7 +3002,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBroadcastDiv() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcastDiv(Nd4jBackend backend) { INDArray num = Nd4j.create(new double[] {1.00, 1.00, 1.00, 1.00, 2.00, 2.00, 2.00, 2.00, 1.00, 1.00, 1.00, 1.00, 2.00, 2.00, 2.00, 2.00, -1.00, -1.00, -1.00, -1.00, -2.00, -2.00, -2.00, -2.00, -1.00, -1.00, -1.00, -1.00, -2.00, -2.00, -2.00, -2.00}).reshape(2, 16); @@ -2768,6 +3022,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBroadcastDiv2(){ INDArray arr = Nd4j.ones(DataType.DOUBLE, 1, 64, 125, 125).muli(2); INDArray vec = Nd4j.ones(DataType.DOUBLE, 64).muli(2); @@ -2783,7 +3039,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBroadcastMult() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcastMult(Nd4jBackend backend) { INDArray num = Nd4j.create(new double[] {1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, -1.00, -2.00, -3.00, -4.00, -5.00, -6.00, -7.00, -8.00}).reshape(2, 8); @@ -2797,7 +3055,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBroadcastSub() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcastSub(Nd4jBackend backend) { INDArray num = Nd4j.create(new double[] {1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, -1.00, -2.00, -3.00, -4.00, -5.00, -6.00, -7.00, -8.00}).reshape(2, 8); @@ -2811,7 +3071,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBroadcastAdd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcastAdd(Nd4jBackend backend) { INDArray num = Nd4j.create(new double[] {1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, -1.00, -2.00, -3.00, -4.00, -5.00, -6.00, -7.00, -8.00}).reshape(2, 8); @@ -2825,7 +3087,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testDimension() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDimension(Nd4jBackend backend) { INDArray test = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2}); //row INDArray slice0 = test.slice(0, 1); @@ -2859,7 +3123,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testReshape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReshape(Nd4jBackend backend) { INDArray arr = Nd4j.create(Nd4j.linspace(1, 24, 24, DataType.DOUBLE).data(), new long[] {4, 3, 2}); INDArray reshaped = arr.reshape(2, 3, 4); assertEquals(arr.length(), reshaped.length()); @@ -2871,6 +3137,8 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDot() throws Exception { INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4}); @@ -2889,7 +3157,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testIdentity() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIdentity(Nd4jBackend backend) { INDArray eye = Nd4j.eye(5); assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape())); eye = Nd4j.eye(5); @@ -2897,7 +3167,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTemp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTemp(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); INDArray in = Nd4j.rand(new long[] {2, 2, 2}); // System.out.println("In:\n" + in); @@ -2914,7 +3186,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testMeans() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMeans(Nd4jBackend backend) { INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray mean1 = a.mean(1); assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage()); @@ -2926,7 +3200,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSums() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSums(Nd4jBackend backend) { INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage()); assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage()); @@ -2936,7 +3212,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testRSubi() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRSubi(Nd4jBackend backend) { INDArray n2 = Nd4j.ones(2); INDArray n2Assertion = Nd4j.zeros(2); INDArray nRsubi = n2.rsubi(1); @@ -2945,7 +3223,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testConcat() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcat(Nd4jBackend backend) { INDArray A = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(2, 2, 2); INDArray B = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2); INDArray concat = Nd4j.concat(0, A, B); @@ -2959,7 +3239,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testConcatHorizontally() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatHorizontally(Nd4jBackend backend) { INDArray rowVector = Nd4j.ones(1, 5); INDArray other = Nd4j.ones(1, 5); INDArray concat = Nd4j.hstack(other, rowVector); @@ -2970,7 +3252,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testArgMaxSameValues() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArgMaxSameValues(Nd4jBackend backend) { //Here: assume that by convention, argmax returns the index of the FIRST maximum value //Thus, argmax(ones(...)) = 0 by convention INDArray arr = Nd4j.ones(DataType.DOUBLE,1,10); @@ -2984,7 +3268,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSoftmaxStability() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftmaxStability(Nd4jBackend backend) { INDArray input = Nd4j.create(new double[] {-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose(); // System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo())); INDArray output = Nd4j.create(10, 1); @@ -2993,7 +3279,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAssignOffset() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAssignOffset(Nd4jBackend backend) { INDArray arr = Nd4j.ones(5, 5); INDArray row = arr.slice(1); row.assign(1); @@ -3001,7 +3289,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAddScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAddScalar(Nd4jBackend backend) { INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4); INDArray rdiv = div.add(1); INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 5); @@ -3009,7 +3299,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testRdivScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRdivScalar(Nd4jBackend backend) { INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4).castTo(DataType.DOUBLE); INDArray rdiv = div.rdiv(1); INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 0.25).castTo(DataType.DOUBLE); @@ -3017,7 +3309,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testRDivi() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRDivi(Nd4jBackend backend) { INDArray n2 = Nd4j.valueArrayOf(new long[] {1, 2}, 4).castTo(DataType.DOUBLE); INDArray n2Assertion = Nd4j.valueArrayOf(new long[] {1, 2}, 0.5).castTo(DataType.DOUBLE); INDArray nRsubi = n2.rdivi(2); @@ -3027,7 +3321,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testElementWiseAdd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testElementWiseAdd(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray linspace2 = linspace.dup(); INDArray assertion = Nd4j.create(new double[][] {{2, 4}, {6, 8}}); @@ -3036,7 +3332,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSquareMatrix() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSquareMatrix(Nd4jBackend backend) { INDArray n = Nd4j.create(Nd4j.linspace(1, 8, 8, DataType.DOUBLE).data(), new long[] {2, 2, 2}); INDArray eightFirstTest = n.vectorAlongDimension(0, 2); INDArray eightFirstAssertion = Nd4j.create(new double[] {1, 2}); @@ -3049,7 +3347,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNumVectorsAlongDimension() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNumVectorsAlongDimension(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2); assertEquals(12, arr.vectorsAlongDimension(2)); } @@ -3057,7 +3357,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testBroadCast() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadCast(Nd4jBackend backend) { INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray broadCasted = n.broadcast(5, 4); for (int i = 0; i < broadCasted.rows(); i++) { @@ -3086,7 +3388,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testScalarBroadcast() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarBroadcast(Nd4jBackend backend) { INDArray fiveThree = Nd4j.ones(5, 3); INDArray fiveThreeTest = Nd4j.scalar(1.0).broadcast(5, 3); assertEquals(fiveThree, fiveThreeTest); @@ -3095,7 +3399,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testPutRowGetRowOrdering() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutRowGetRowOrdering(Nd4jBackend backend) { INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray put = Nd4j.create(new double[] {5, 6}); row1.putRow(1, put); @@ -3116,7 +3422,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testElementWiseOps() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testElementWiseOps(Nd4jBackend backend) { INDArray n1 = Nd4j.scalar(1.0); INDArray n2 = Nd4j.scalar(2.0); INDArray nClone = n1.add(n2); @@ -3137,7 +3445,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNdArrayCreation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNdArrayCreation(Nd4jBackend backend) { double delta = 1e-1; INDArray n1 = Nd4j.create(new double[] {0d, 1d, 2d, 3d}, new long[] {2, 2}, 'c'); INDArray lv = n1.reshape(-1); @@ -3148,7 +3458,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testToFlattenedWithOrder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToFlattenedWithOrder(Nd4jBackend backend) { int[] firstShape = {10, 3}; int firstLen = ArrayUtil.prod(firstShape); int[] secondShape = {2, 7}; @@ -3186,7 +3498,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testLeakyRelu() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLeakyRelu(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(-1, 1, 10, DataType.DOUBLE); double[] expected = new double[10]; for (int i = 0; i < 10; i++) { @@ -3201,7 +3515,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSoftmaxRow() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftmaxRow(Nd4jBackend backend) { for (int i = 0; i < 20; i++) { INDArray arr1 = Nd4j.zeros(1, 100); Nd4j.getExecutioner().execAndReturn(new SoftMax(arr1)); @@ -3210,7 +3526,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testLeakyRelu2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLeakyRelu2(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(-1, 1, 10, DataType.DOUBLE); double[] expected = new double[10]; for (int i = 0; i < 10; i++) { @@ -3228,7 +3546,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testDupAndDupWithOrder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDupAndDupWithOrder(Nd4jBackend backend) { List> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(ordering(), 4, 5, 123, DataType.DOUBLE); for (Pair pair : testInputs) { @@ -3248,7 +3568,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testToOffsetZeroCopy() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToOffsetZeroCopy(Nd4jBackend backend) { List> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(ordering(), 4, 5, 123, DataType.DOUBLE); @@ -3282,13 +3604,15 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test @Disabled - public void largeInstantiation() { + public void largeInstantiation(Nd4jBackend backend) { Nd4j.ones((1024 * 1024 * 511) + (1024 * 1024 - 1)); // Still works; this can even be called as often as I want, allowing me even to spill over on disk Nd4j.ones((1024 * 1024 * 511) + (1024 * 1024)); // Crashes } @Test - public void testAssignNumber() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAssignNumber(Nd4jBackend backend) { int nRows = 10; int nCols = 20; INDArray in = Nd4j.linspace(1, nRows * nCols, nRows * nCols, DataType.DOUBLE).reshape('c', new long[] {nRows, nCols}); @@ -3317,7 +3641,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSumDifferentOrdersSquareMatrix() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSumDifferentOrdersSquareMatrix(Nd4jBackend backend) { INDArray arrc = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray arrf = Nd4j.create(new long[] {2, 2}, 'f').assign(arrc); @@ -3329,7 +3655,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test @Disabled //not relevant anymore - public void testAssignMixedC() { + public void testAssignMixedC(Nd4jBackend backend) { int[] shape1 = {3, 2, 2, 2, 2, 2}; int[] shape2 = {12, 8}; int length = ArrayUtil.prod(shape1); @@ -3358,7 +3684,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testDummy() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDummy(Nd4jBackend backend) { INDArray arr2f = Nd4j.create(new double[] {1.0, 13.0, 25.0, 37.0, 49.0, 61.0, 73.0, 85.0, 2.0, 14.0, 26.0, 38.0, 50.0, 62.0, 74.0, 86.0, 3.0, 15.0, 27.0, 39.0, 51.0, 63.0, 75.0, 87.0, 4.0, 16.0, 28.0, 40.0, 52.0, 64.0, 76.0, 88.0, 5.0, 17.0, 29.0, 41.0, 53.0, 65.0, 77.0, 89.0, 6.0, 18.0, 30.0, 42.0, @@ -3384,7 +3712,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testCreateDetached_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCreateDetached_1(Nd4jBackend backend) { val shape = new int[]{10}; val dataTypes = new DataType[] {DataType.DOUBLE, DataType.BOOL, DataType.BYTE, DataType.UBYTE, DataType.SHORT, DataType.UINT16, DataType.INT, DataType.UINT32, DataType.LONG, DataType.UINT64, DataType.FLOAT, DataType.BFLOAT16, DataType.HALF}; @@ -3395,7 +3725,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testCreateDetached_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCreateDetached_2(Nd4jBackend backend) { val shape = new long[]{10}; val dataTypes = new DataType[] {DataType.DOUBLE, DataType.BOOL, DataType.BYTE, DataType.UBYTE, DataType.SHORT, DataType.UINT16, DataType.INT, DataType.UINT32, DataType.LONG, DataType.UINT64, DataType.FLOAT, DataType.BFLOAT16, DataType.HALF}; @@ -3406,7 +3738,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testPairwiseMixedC() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPairwiseMixedC(Nd4jBackend backend) { int[] shape2 = {12, 8}; int length = ArrayUtil.prod(shape2); @@ -3431,7 +3765,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testPairwiseMixedF() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPairwiseMixedF(Nd4jBackend backend) { int[] shape2 = {12, 8}; int length = ArrayUtil.prod(shape2); @@ -3456,7 +3792,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAssign2D() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAssign2D(Nd4jBackend backend) { int[] shape2 = {8, 4}; int length = ArrayUtil.prod(shape2); @@ -3476,7 +3814,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAssign2D_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAssign2D_2(Nd4jBackend backend) { int[] shape2 = {8, 4}; int length = ArrayUtil.prod(shape2); @@ -3504,7 +3844,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAssign3D_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAssign3D_2(Nd4jBackend backend) { int[] shape3 = {8, 4, 8}; int length = ArrayUtil.prod(shape3); @@ -3526,7 +3868,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSumDifferentOrders() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSumDifferentOrders(Nd4jBackend backend) { INDArray arrc = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape('c', 3, 2); INDArray arrf = Nd4j.create(new double[6], new long[] {3, 2}, 'f').assign(arrc); @@ -3537,7 +3881,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testCreateUnitialized() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCreateUnitialized(Nd4jBackend backend) { INDArray arrC = Nd4j.createUninitialized(new long[] {10, 10}, 'c'); INDArray arrF = Nd4j.createUninitialized(new long[] {10, 10}, 'f'); @@ -3556,7 +3902,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testVarConst() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVarConst(Nd4jBackend backend) { INDArray x = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10); // System.out.println(x); assertFalse(Double.isNaN(x.var(0).sumNumber().doubleValue())); @@ -3600,7 +3948,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testVPull1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVPull1(Nd4jBackend backend) { int indexes[] = new int[] {0, 2, 4}; INDArray array = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(5, 5); INDArray assertion = Nd4j.createUninitialized(new long[] {3, 5}, 'f'); @@ -3616,7 +3966,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test() - public void testPullRowsValidation1() { + public void testPullRowsValidation1(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { Nd4j.pullRows(Nd4j.create(10, 10), 2, new int[] {0, 1, 2}); @@ -3624,7 +3974,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test() - public void testPullRowsValidation2() { + public void testPullRowsValidation2(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, -1, 2}); @@ -3632,7 +3982,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test() - public void testPullRowsValidation3() { + public void testPullRowsValidation3(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, 1, 10}); @@ -3640,7 +3990,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test() - public void testPullRowsValidation4() { + public void testPullRowsValidation4(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2, 3}); @@ -3648,7 +3998,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test() - public void testPullRowsValidation5() { + public void testPullRowsValidation5(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2}, 'e'); @@ -3658,7 +4008,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testVPull2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVPull2(Nd4jBackend backend) { val indexes = new int[] {0, 2, 4}; INDArray array = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(5, 5); INDArray assertion = Nd4j.createUninitialized(new long[] {3, 5}, 'c'); @@ -3678,7 +4030,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testCompareAndSet1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCompareAndSet1(Nd4jBackend backend) { INDArray array = Nd4j.zeros(25); INDArray assertion = Nd4j.zeros(25); @@ -3693,7 +4047,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReplaceNaNs() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReplaceNaNs(Nd4jBackend backend) { INDArray array = Nd4j.zeros(25); INDArray assertion = Nd4j.zeros(25); @@ -3711,7 +4067,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNaNEquality() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNaNEquality(Nd4jBackend backend) { INDArray array = Nd4j.zeros(25); INDArray assertion = Nd4j.zeros(25); @@ -3724,7 +4082,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSingleDeviceAveraging() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSingleDeviceAveraging(Nd4jBackend backend) { int LENGTH = 512 * 1024 * 2; INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0); INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0); @@ -3766,7 +4126,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testDistance1and2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDistance1and2(Nd4jBackend backend) { double[] d1 = new double[] {-1, 3, 2}; double[] d2 = new double[] {0, 1.5, -3.5}; INDArray arr1 = Nd4j.create(d1); @@ -3787,7 +4149,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testEqualsWithEps1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEqualsWithEps1(Nd4jBackend backend) { INDArray array1 = Nd4j.create(new double[] {0.5f, 1.5f, 2.5f, 3.5f, 4.5f}); INDArray array2 = Nd4j.create(new double[] {0f, 1f, 2f, 3f, 4f}); INDArray array3 = Nd4j.create(new double[] {0f, 1.000001f, 2f, 3f, 4f}); @@ -3800,7 +4164,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testIMaxIAMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIMaxIAMax(Nd4jBackend backend) { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ALL); INDArray arr = Nd4j.create(new double[] {-0.24, -0.26, -0.07, -0.01}); @@ -3816,7 +4182,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testIMinIAMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIMinIAMin(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[] {-0.24, -0.26, -0.07, -0.01}); INDArray abs = Transforms.abs(arr); val iaMin = new ArgAmin(abs); @@ -3831,7 +4199,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testBroadcast3d2d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcast3d2d(Nd4jBackend backend) { char[] orders = {'c', 'f'}; for (char orderArr : orders) { @@ -3879,7 +4249,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBroadcast4d2d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcast4d2d(Nd4jBackend backend) { char[] orders = {'c', 'f'}; for (char orderArr : orders) { @@ -3998,7 +4370,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testIsMax2Of3d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsMax2Of3d(Nd4jBackend backend) { double[][][] slices = new double[3][][]; boolean[][][] isMax = new boolean[3][][]; @@ -4025,7 +4399,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testIsMax2of4d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsMax2of4d(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); val s = new long[] {2, 3, 4, 5}; @@ -4101,7 +4477,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testIMax2Of3d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIMax2Of3d(Nd4jBackend backend) { double[][][] slices = new double[3][][]; slices[0] = new double[][] {{1, 10, 2}, {3, 4, 5}}; @@ -4127,7 +4505,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testIMax2of4d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIMax2of4d(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); val s = new long[] {2, 3, 4, 5}; INDArray arr = Nd4j.rand(s); @@ -4200,7 +4580,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTadPermuteEquals() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTadPermuteEquals(Nd4jBackend backend) { INDArray d3c = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape('c', 1, 5, 1); INDArray d3f = d3c.dup('f'); @@ -4225,7 +4607,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testRemainder1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRemainder1(Nd4jBackend backend) { INDArray x = Nd4j.create(10).assign(5.3); INDArray y = Nd4j.create(10).assign(2.0); INDArray exp = Nd4j.create(10).assign(-0.7); @@ -4238,7 +4622,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testFMod1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFMod1(Nd4jBackend backend) { INDArray x = Nd4j.create(10).assign(5.3); INDArray y = Nd4j.create(10).assign(2.0); INDArray exp = Nd4j.create(10).assign(1.3); @@ -4251,7 +4637,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testStrangeDups1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStrangeDups1(Nd4jBackend backend) { INDArray array = Nd4j.create(10).assign(0); INDArray exp = Nd4j.create(10).assign(1.0f); INDArray copy = null; @@ -4266,7 +4654,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testStrangeDups2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStrangeDups2(Nd4jBackend backend) { INDArray array = Nd4j.create(10).assign(0); INDArray exp1 = Nd4j.create(10).assign(1.0f); INDArray exp2 = Nd4j.create(10).assign(1.0f).putScalar(9, 0f); @@ -4282,7 +4672,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReductionAgreement1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReductionAgreement1(Nd4jBackend backend) { INDArray row = Nd4j.linspace(1, 3, 3, DataType.DOUBLE).reshape(1, 3); INDArray mean0 = row.mean(0); assertFalse(mean0 == row); //True: same object (should be a copy) @@ -4294,7 +4686,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSpecialConcat1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSpecialConcat1(Nd4jBackend backend) { for (int i = 0; i < 10; i++) { List arrays = new ArrayList<>(); for (int x = 0; x < 10; x++) { @@ -4314,7 +4708,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testSpecialConcat2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSpecialConcat2(Nd4jBackend backend) { List arrays = new ArrayList<>(); for (int x = 0; x < 10; x++) { arrays.add(Nd4j.create(new double[] {x, x, x, x, x, x}).reshape(1, 6)); @@ -4333,7 +4729,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testPutScalar1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutScalar1(Nd4jBackend backend) { INDArray array = Nd4j.create(10, 3, 96, 96); for (int i = 0; i < 10; i++) { @@ -4343,7 +4741,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAveraging1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAveraging1(Nd4jBackend backend) { Nd4j.getAffinityManager().allowCrossDeviceAccess(false); List arrays = new ArrayList<>(); @@ -4361,7 +4761,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAveraging2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAveraging2(Nd4jBackend backend) { List arrays = new ArrayList<>(); for (int i = 0; i < 10; i++) { @@ -4380,7 +4782,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAveraging3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAveraging3(Nd4jBackend backend) { Nd4j.getAffinityManager().allowCrossDeviceAccess(false); List arrays = new ArrayList<>(); @@ -4400,7 +4804,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testZ1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testZ1(Nd4jBackend backend) { INDArray matrix = Nd4j.create(10, 10).assign(1.0); INDArray exp = Nd4j.create(10).assign(10.0); @@ -4414,7 +4820,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testDupDelayed() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDupDelayed(Nd4jBackend backend) { if (!(Nd4j.getExecutioner() instanceof GridExecutioner)) return; @@ -4464,7 +4872,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testScalarReduction1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarReduction1(Nd4jBackend backend) { val op = new Norm2(Nd4j.create(1).assign(1.0)); double norm2 = Nd4j.getExecutioner().execAndReturn(op).getFinalResult().doubleValue(); double norm1 = Nd4j.getExecutioner().execAndReturn(new Norm1(Nd4j.create(1).assign(1.0))).getFinalResult() @@ -4479,7 +4889,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void tesAbsReductions1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void tesAbsReductions1(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {-1, -2, -3, -4}); assertEquals(4, array.amaxNumber().intValue()); @@ -4487,7 +4899,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void tesAbsReductions2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void tesAbsReductions2(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {-1, -2, -3, -4}); assertEquals(1, array.aminNumber().intValue()); @@ -4495,7 +4909,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void tesAbsReductions3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void tesAbsReductions3(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {-2, -2, 2, 2}); assertEquals(2, array.ameanNumber().intValue()); @@ -4503,7 +4919,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void tesAbsReductions4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void tesAbsReductions4(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {-2, -2, 2, 3}); assertEquals(1.0, array.sumNumber().doubleValue(), 1e-5); @@ -4511,14 +4929,18 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void tesAbsReductions5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void tesAbsReductions5(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {-2, 0.0, 2, 2}); assertEquals(3, array.scan(Conditions.absGreaterThan(0.0)).intValue()); } @Test - public void testNewBroadcastComparison1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNewBroadcastComparison1(Nd4jBackend backend) { val initial = Nd4j.create(3, 5); val mask = Nd4j.create(new double[] {5, 4, 3, 2, 1}); val result = Nd4j.createUninitialized(DataType.BOOL, initial.shape()); @@ -4545,7 +4967,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testNewBroadcastComparison2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNewBroadcastComparison2(Nd4jBackend backend) { val initial = Nd4j.create(3, 5); val mask = Nd4j.create(new double[] {5, 4, 3, 2, 1}); val result = Nd4j.createUninitialized(DataType.BOOL, initial.shape()); @@ -4569,7 +4993,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testNewBroadcastComparison3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNewBroadcastComparison3(Nd4jBackend backend) { val initial = Nd4j.create(3, 5); val mask = Nd4j.create(new double[] {5, 4, 3, 2, 1}); val result = Nd4j.createUninitialized(DataType.BOOL, initial.shape()); @@ -4591,7 +5017,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNewBroadcastComparison4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNewBroadcastComparison4(Nd4jBackend backend) { val initial = Nd4j.create(3, 5); val mask = Nd4j.create(new double[] {5, 4, 3, 2, 1}); val result = Nd4j.createUninitialized(DataType.BOOL, initial.shape()); @@ -4613,7 +5041,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTadDup_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTadDup_1(Nd4jBackend backend) { INDArray haystack = Nd4j.create(new double[] {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951, -0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673, 0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247}) @@ -4628,7 +5058,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTadReduce3_0() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTadReduce3_0(Nd4jBackend backend) { INDArray haystack = Nd4j.create(new double[] {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951, -0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673, 0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247}) @@ -4649,7 +5081,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReduce3SignaturesEquality_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduce3SignaturesEquality_1(Nd4jBackend backend) { val x = Nd4j.rand(DataType.DOUBLE, 3, 4, 5); val y = Nd4j.rand(DataType.DOUBLE, 3, 4, 5); @@ -4663,7 +5097,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTadReduce3_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTadReduce3_1(Nd4jBackend backend) { INDArray initial = Nd4j.create(5, 10); for (int i = 0; i < initial.rows(); i++) { initial.getRow(i).assign(i + 1); @@ -4681,7 +5117,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTadReduce3_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTadReduce3_2(Nd4jBackend backend) { INDArray initial = Nd4j.create(5, 10); for (int i = 0; i < initial.rows(); i++) { initial.getRow(i).assign(i + 1); @@ -4699,7 +5137,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTadReduce3_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTadReduce3_3(Nd4jBackend backend) { INDArray initial = Nd4j.create(5, 10); for (int i = 0; i < initial.rows(); i++) { initial.getRow(i).assign(i + 1); @@ -4718,7 +5158,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTadReduce3_3_NEG() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTadReduce3_3_NEG(Nd4jBackend backend) { INDArray initial = Nd4j.create(5, 10); for (int i = 0; i < initial.rows(); i++) { initial.getRow(i).assign(i + 1); @@ -4737,7 +5179,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTadReduce3_3_NEG_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTadReduce3_3_NEG_2(Nd4jBackend backend) { INDArray initial = Nd4j.create(5, 10); for (int i = 0; i < initial.rows(); i++) { initial.getRow(i).assign(i + 1); @@ -4757,7 +5201,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test() - public void testTadReduce3_5() { + public void testTadReduce3_5(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { INDArray initial = Nd4j.create(5, 10); for (int i = 0; i < initial.rows(); i++) { @@ -4771,7 +5215,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTadReduce3_4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTadReduce3_4(Nd4jBackend backend) { INDArray initial = Nd4j.create(5, 6, 7); for (int i = 0; i < 5; i++) { initial.tensorAlongDimension(i, 1, 2).assign(i + 1); @@ -4790,7 +5236,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAtan2_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAtan2_1(Nd4jBackend backend) { INDArray x = Nd4j.create(10).assign(-1.0); INDArray y = Nd4j.create(10).assign(0.0); INDArray exp = Nd4j.create(10).assign(Math.PI); @@ -4801,7 +5249,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAtan2_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAtan2_2(Nd4jBackend backend) { INDArray x = Nd4j.create(10).assign(1.0); INDArray y = Nd4j.create(10).assign(0.0); INDArray exp = Nd4j.create(10).assign(0.0); @@ -4812,7 +5262,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testJaccardDistance1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJaccardDistance1(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {0, 1, 0, 0, 1, 0}); INDArray y = Nd4j.create(new double[] {1, 1, 0, 1, 0, 0}); @@ -4822,7 +5274,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testJaccardDistance2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJaccardDistance2(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {0, 1, 0, 0, 1, 1}); INDArray y = Nd4j.create(new double[] {1, 1, 0, 1, 0, 0}); @@ -4832,7 +5286,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testHammingDistance1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHammingDistance1(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {0, 0, 0, 1, 0, 0}); INDArray y = Nd4j.create(new double[] {0, 0, 0, 0, 1, 0}); @@ -4842,7 +5298,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testHammingDistance2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHammingDistance2(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {0, 0, 0, 1, 0, 0}); INDArray y = Nd4j.create(new double[] {0, 1, 0, 0, 1, 0}); @@ -4852,7 +5310,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testHammingDistance3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHammingDistance3(Nd4jBackend backend) { INDArray x = Nd4j.create(DataType.DOUBLE, 10, 6); for (int r = 0; r < x.rows(); r++) { val p = r % x.columns(); @@ -4874,7 +5334,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAllDistances1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllDistances1(Nd4jBackend backend) { INDArray initialX = Nd4j.create(5, 10); INDArray initialY = Nd4j.create(7, 10); for (int i = 0; i < initialX.rows(); i++) { @@ -4906,7 +5368,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAllDistances2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllDistances2(Nd4jBackend backend) { INDArray initialX = Nd4j.create(5, 10); INDArray initialY = Nd4j.create(7, 10); for (int i = 0; i < initialX.rows(); i++) { @@ -4936,7 +5400,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAllDistances2_Large() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllDistances2_Large(Nd4jBackend backend) { INDArray initialX = Nd4j.create(5, 2000); INDArray initialY = Nd4j.create(7, 2000); for (int i = 0; i < initialX.rows(); i++) { @@ -4966,7 +5432,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAllDistances3_Large() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllDistances3_Large(Nd4jBackend backend) { INDArray initialX = Nd4j.create(5, 2000); INDArray initialY = Nd4j.create(7, 2000); for (int i = 0; i < initialX.rows(); i++) { @@ -4998,7 +5466,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAllDistances3_Large_Columns() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllDistances3_Large_Columns(Nd4jBackend backend) { INDArray initialX = Nd4j.create(2000, 5); INDArray initialY = Nd4j.create(2000, 7); for (int i = 0; i < initialX.columns(); i++) { @@ -5028,7 +5498,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAllDistances4_Large_Columns() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllDistances4_Large_Columns(Nd4jBackend backend) { INDArray initialX = Nd4j.create(2000, 5); INDArray initialY = Nd4j.create(2000, 7); for (int i = 0; i < initialX.columns(); i++) { @@ -5058,7 +5530,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAllDistances5_Large_Columns() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllDistances5_Large_Columns(Nd4jBackend backend) { INDArray initialX = Nd4j.create(2000, 5); INDArray initialY = Nd4j.create(2000, 7); for (int i = 0; i < initialX.columns(); i++) { @@ -5088,7 +5562,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAllDistances3_Small_Columns() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllDistances3_Small_Columns(Nd4jBackend backend) { INDArray initialX = Nd4j.create(200, 5); INDArray initialY = Nd4j.create(200, 7); for (int i = 0; i < initialX.columns(); i++) { @@ -5117,7 +5593,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAllDistances3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllDistances3(Nd4jBackend backend) { Nd4j.getRandom().setSeed(123); INDArray initialX = Nd4j.rand(5, 10); @@ -5142,7 +5620,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testStridedTransforms1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStridedTransforms1(Nd4jBackend backend) { //output: Rank: 2,Offset: 0 //Order: c Shape: [5,2], stride: [2,1] //output: [0.5086864, 0.49131358, 0.50720876, 0.4927912, 0.46074104, 0.53925896, 0.49314, 0.50686, 0.5217741, 0.4782259] @@ -5170,7 +5650,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testEntropy1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEntropy1(Nd4jBackend backend) { INDArray x = Nd4j.rand(1, 100); double exp = MathUtils.entropy(x.data().asDouble()); @@ -5180,7 +5662,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testEntropy2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEntropy2(Nd4jBackend backend) { INDArray x = Nd4j.rand(10, 100); INDArray res = x.entropy(1); @@ -5195,7 +5679,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testEntropy3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEntropy3(Nd4jBackend backend) { INDArray x = Nd4j.rand(1, 100); double exp = getShannonEntropy(x.data().asDouble()); @@ -5205,7 +5691,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testEntropy4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEntropy4(Nd4jBackend backend) { INDArray x = Nd4j.rand(1, 100); double exp = getLogEntropy(x.data().asDouble()); @@ -5228,7 +5716,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReverse1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverse1(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); @@ -5238,7 +5728,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReverse2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverse2(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); @@ -5248,7 +5740,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReverse3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverse3(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); @@ -5258,7 +5752,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReverse4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverse4(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); @@ -5268,7 +5764,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReverse5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverse5(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); @@ -5280,7 +5778,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testReverse6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverse6(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); @@ -5291,7 +5791,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNativeSortView1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNativeSortView1(Nd4jBackend backend) { INDArray matrix = Nd4j.create(10, 10); INDArray exp = Nd4j.linspace(0, 9, 10, DataType.DOUBLE); int cnt = 0; @@ -5306,7 +5808,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNativeSort1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNativeSort1(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {9, 2, 1, 7, 6, 5, 4, 3, 8, 0}); INDArray exp1 = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); INDArray exp2 = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); @@ -5321,7 +5825,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNativeSort2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNativeSort2(Nd4jBackend backend) { INDArray array = Nd4j.rand(1, 10000); INDArray res = Nd4j.sort(array, true); @@ -5334,7 +5840,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNativeSort3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNativeSort3(Nd4jBackend backend) { int length = isIntegrationTests() ? 1048576 : 16484; INDArray array = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape(1, -1); INDArray exp = array.dup(); @@ -5349,6 +5857,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLongShapeDescriptor(){ Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); INDArray arr = Nd4j.create(new float[]{1,2,3}); @@ -5358,7 +5868,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReverseSmall_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverseSmall_1(Nd4jBackend backend) { val array = Nd4j.linspace(1, 10, 10, DataType.INT); val exp = array.dup(array.ordering()); @@ -5372,7 +5884,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReverseSmall_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverseSmall_2(Nd4jBackend backend) { val array = Nd4j.linspace(1, 10, 10, DataType.INT); val exp = array.dup(array.ordering()); @@ -5386,7 +5900,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReverseSmall_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverseSmall_3(Nd4jBackend backend) { val array = Nd4j.linspace(1, 11, 11, DataType.INT); val exp = array.dup(array.ordering()); @@ -5401,7 +5917,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReverseSmall_4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverseSmall_4(Nd4jBackend backend) { val array = Nd4j.linspace(1, 11, 11, DataType.INT); val exp = array.dup(array.ordering()); @@ -5415,7 +5933,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReverse_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverse_1(Nd4jBackend backend) { val array = Nd4j.linspace(1, 2017152, 2017152, DataType.INT); val exp = array.dup(array.ordering()); @@ -5429,7 +5949,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReverse_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverse_2(Nd4jBackend backend) { val array = Nd4j.linspace(1, 2017152, 2017152, DataType.INT); val exp = array.dup(array.ordering()); @@ -5443,7 +5965,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNativeSort3_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNativeSort3_1(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 2017152, 2017152, DataType.DOUBLE).reshape(1, -1); INDArray exp = array.dup(); Transforms.reverse(array, false); @@ -5457,7 +5981,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNativeSortAlongDimension1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNativeSortAlongDimension1(Nd4jBackend backend) { INDArray array = Nd4j.create(1000, 1000); INDArray exp1 = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE); INDArray dps = exp1.dup(); @@ -5499,7 +6025,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void shuffleTest() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void shuffleTest(Nd4jBackend backend) { for (int e = 0; e < 5; e++) { //log.info("---------------------"); val array = Nd4j.linspace(1, 1011, 1011, DataType.INT); @@ -5515,7 +6043,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNativeSortAlongDimension3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNativeSortAlongDimension3(Nd4jBackend backend) { INDArray array = Nd4j.create(2000, 2000); INDArray exp1 = Nd4j.linspace(1, 2000, 2000, DataType.DOUBLE); INDArray dps = exp1.dup(); @@ -5549,7 +6079,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testNativeSortAlongDimension2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNativeSortAlongDimension2(Nd4jBackend backend) { INDArray array = Nd4j.create(100, 10); INDArray exp1 = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); @@ -5566,7 +6098,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testPercentile1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPercentile1(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); Percentile percentile = new Percentile(50); double exp = percentile.evaluate(array.data().asDouble()); @@ -5575,7 +6109,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testPercentile2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPercentile2(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 9, 9, DataType.DOUBLE); Percentile percentile = new Percentile(50); double exp = percentile.evaluate(array.data().asDouble()); @@ -5585,7 +6121,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testPercentile3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPercentile3(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 9, 9, DataType.DOUBLE); Percentile percentile = new Percentile(75); double exp = percentile.evaluate(array.data().asDouble()); @@ -5594,7 +6132,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testPercentile4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPercentile4(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); Percentile percentile = new Percentile(75); double exp = percentile.evaluate(array.data().asDouble()); @@ -5603,14 +6143,18 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testPercentile5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPercentile5(Nd4jBackend backend) { val array = Nd4j.createFromArray(new int[]{1, 1982}); val perc = array.percentileNumber(75); assertEquals(1982.f, perc.floatValue(), 1e-5f); } @Test - public void testTadPercentile1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTadPercentile1(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); Transforms.reverse(array, false); Percentile percentile = new Percentile(75); @@ -5627,7 +6171,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testPutiRowVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutiRowVector(Nd4jBackend backend) { INDArray matrix = Nd4j.createUninitialized(10, 10); INDArray exp = Nd4j.create(10, 10).assign(1.0); INDArray row = Nd4j.create(10).assign(1.0); @@ -5638,7 +6184,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testPutiColumnsVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutiColumnsVector(Nd4jBackend backend) { INDArray matrix = Nd4j.createUninitialized(5, 10); INDArray exp = Nd4j.create(5, 10).assign(1.0); INDArray row = Nd4j.create(5, 1).assign(1.0); @@ -5651,7 +6199,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testRsub1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRsub1(Nd4jBackend backend) { INDArray arr = Nd4j.ones(5).assign(2.0); INDArray exp_0 = Nd4j.ones(5).assign(2.0); INDArray exp_1 = Nd4j.create(5).assign(-1); @@ -5665,7 +6215,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBroadcastMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcastMin(Nd4jBackend backend) { INDArray matrix = Nd4j.create(5, 5); for (int r = 0; r < matrix.rows(); r++) { matrix.getRow(r).assign(Nd4j.create(new double[]{2, 3, 3, 4, 5})); @@ -5681,7 +6233,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBroadcastMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcastMax(Nd4jBackend backend) { INDArray matrix = Nd4j.create(5, 5); for (int r = 0; r < matrix.rows(); r++) { matrix.getRow(r).assign(Nd4j.create(new double[]{1, 2, 3, 2, 1})); @@ -5697,7 +6251,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBroadcastAMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcastAMax(Nd4jBackend backend) { INDArray matrix = Nd4j.create(5, 5); for (int r = 0; r < matrix.rows(); r++) { matrix.getRow(r).assign(Nd4j.create(new double[]{1, 2, 3, 2, 1})); @@ -5713,7 +6269,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBroadcastAMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcastAMin(Nd4jBackend backend) { INDArray matrix = Nd4j.create(5, 5); for (int r = 0; r < matrix.rows(); r++) { matrix.getRow(r).assign(Nd4j.create(new double[]{2, 3, 3, 4, 1})); @@ -5730,7 +6288,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test @Disabled - public void testLogExpSum1() { + public void testLogExpSum1(Nd4jBackend backend) { INDArray matrix = Nd4j.create(3, 3); for (int r = 0; r < matrix.rows(); r++) { matrix.getRow(r).assign(Nd4j.create(new double[]{1, 2, 3})); @@ -5745,7 +6303,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test @Disabled - public void testLogExpSum2() { + public void testLogExpSum2(Nd4jBackend backend) { INDArray row = Nd4j.create(new double[]{1, 2, 3}); double res = Nd4j.getExecutioner().exec(new LogSumExp(row))[0].getDouble(0); @@ -5754,7 +6312,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testPow1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPow1(Nd4jBackend backend) { val argX = Nd4j.create(3).assign(2.0); val argY = Nd4j.create(new double[]{1.0, 2.0, 3.0}); val exp = Nd4j.create(new double[] {2.0, 4.0, 8.0}); @@ -5764,7 +6324,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testRDiv1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRDiv1(Nd4jBackend backend) { val argX = Nd4j.create(3).assign(2.0); val argY = Nd4j.create(new double[]{1.0, 2.0, 3.0}); val exp = Nd4j.create(new double[] {0.5, 1.0, 1.5}); @@ -5774,7 +6336,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testEqualOrder1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEqualOrder1(Nd4jBackend backend) { val array = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); val arrayC = array.dup('c'); val arrayF = array.dup('f'); @@ -5785,7 +6349,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMatchTransform() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatchTransform(Nd4jBackend backend) { val array = Nd4j.create(new double[] {1, 1, 1, 0, 1, 1},'c'); val result = Nd4j.createUninitialized(DataType.BOOL, array.shape()); val exp = Nd4j.create(new boolean[] {false, false, false, true, false, false}); @@ -5797,7 +6363,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void test4DSumView() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test4DSumView(Nd4jBackend backend) { INDArray labels = Nd4j.linspace(1, 160, 160, DataType.DOUBLE).reshape(2, 5, 4, 4); //INDArray labels = Nd4j.linspace(1, 192, 192).reshape(new long[]{2, 6, 4, 4}); @@ -5823,7 +6391,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMatMul1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatMul1(Nd4jBackend backend) { val x = 2; val A1 = 3; val A2 = 4; @@ -5835,7 +6405,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReduction_Z1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduction_Z1(Nd4jBackend backend) { val arrayX = Nd4j.create(10, 10, 10); val res = arrayX.max(1, 2); @@ -5844,7 +6416,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReduction_Z2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduction_Z2(Nd4jBackend backend) { val arrayX = Nd4j.create(10, 10); val res = arrayX.max(0); @@ -5853,7 +6427,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReduction_Z3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduction_Z3(Nd4jBackend backend) { val arrayX = Nd4j.create(200, 300); val res = arrayX.maxNumber().doubleValue(); @@ -5862,7 +6438,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSoftmaxZ1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftmaxZ1(Nd4jBackend backend) { val original = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10); val reference = original.dup(original.ordering()); val expected = original.dup(original.ordering()); @@ -5876,7 +6454,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testRDiv() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRDiv(Nd4jBackend backend) { val x = Nd4j.create(new double[]{2,2,2}); val y = Nd4j.create(new double[]{4,6,8}); val result = Nd4j.createUninitialized(DataType.DOUBLE, 3); @@ -5898,7 +6478,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testIm2Col() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIm2Col(Nd4jBackend backend) { int kY = 5; int kX = 5; int sY = 1; @@ -5939,7 +6521,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testGemmStrides() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemmStrides(Nd4jBackend backend) { // 4x5 matrix from arange(20) final INDArray X = Nd4j.arange(20).reshape(4,5); for (int i=0; i<5; i++){ @@ -5958,7 +6542,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test() - public void testReshapeFailure() { + public void testReshapeFailure(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { val a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2); val b = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2); @@ -5971,7 +6555,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testScalar_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalar_1(Nd4jBackend backend) { val scalar = Nd4j.create(new float[]{2.0f}, new long[]{}); assertTrue(scalar.isScalar()); @@ -5985,7 +6571,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testScalar_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalar_2(Nd4jBackend backend) { val scalar = Nd4j.scalar(2.0f); val scalar2 = Nd4j.scalar(2.0f); val scalar3 = Nd4j.scalar(3.0f); @@ -6004,7 +6592,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testVector_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVector_1(Nd4jBackend backend) { val vector = Nd4j.createFromArray(new float[] {1, 2, 3, 4, 5}); val vector2 = Nd4j.createFromArray(new float[] {1, 2, 3, 4, 5}); val vector3 = Nd4j.createFromArray(new float[] {1, 2, 3, 4, 6}); @@ -6021,7 +6611,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testVectorScalar_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorScalar_2(Nd4jBackend backend) { val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5}); val scalar = Nd4j.scalar(2.0f); val exp = Nd4j.createFromArray(new float[]{3, 4, 5, 6, 7}); @@ -6032,7 +6624,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReshapeScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReshapeScalar(Nd4jBackend backend) { val scalar = Nd4j.scalar(2.0f); val newShape = scalar.reshape(1, 1, 1, 1); @@ -6042,7 +6636,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testReshapeVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReshapeVector(Nd4jBackend backend) { val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6}); val newShape = vector.reshape(3, 2); @@ -6051,7 +6647,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test() - public void testTranspose1() { + public void testTranspose1(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6}); @@ -6066,7 +6662,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test() - public void testTranspose2() { + public void testTranspose2(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { val scalar = Nd4j.scalar(2.f); @@ -6082,7 +6678,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test //@Disabled - public void testMatmul_128by256() { + public void testMatmul_128by256(Nd4jBackend backend) { val mA = Nd4j.create(128, 156).assign(1.0f); val mB = Nd4j.create(156, 256).assign(1.0f); @@ -6107,7 +6703,9 @@ public class Nd4jTestsC extends BaseNd4jTest { c = tf.matmul(a, b) */ @Test - public void testMatmul_Empty() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatmul_Empty(Nd4jBackend backend) { val mA = Nd4j.create(0,1); val mB = Nd4j.create(1,0); val mC = Nd4j.create(0,0); @@ -6122,7 +6720,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMatmul_Empty1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatmul_Empty1(Nd4jBackend backend) { val mA = Nd4j.create(1,0, 4); val mB = Nd4j.create(1,4, 0); val mC = Nd4j.create(1,0, 0); @@ -6138,7 +6738,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testScalarSqueeze() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarSqueeze(Nd4jBackend backend) { val scalar = Nd4j.create(new float[]{2.0f}, new long[]{1, 1}); val output = Nd4j.scalar(0.0f); val exp = Nd4j.scalar(2.0f); @@ -6156,7 +6758,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testScalarVectorSqueeze() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarVectorSqueeze(Nd4jBackend backend) { val scalar = Nd4j.create(new float[]{2.0f}, new long[]{1}); assertArrayEquals(new long[]{1}, scalar.shape()); @@ -6177,7 +6781,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testVectorSqueeze() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorSqueeze(Nd4jBackend backend) { val vector = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6}, new long[]{1, 6}); val output = Nd4j.createFromArray(new float[] {0, 0, 0, 0, 0, 0}); val exp = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6}); @@ -6196,7 +6802,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMatrixReshape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatrixReshape(Nd4jBackend backend) { val matrix = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, new long[] {3, 3}); val exp = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}, new long[] {9}); @@ -6208,7 +6816,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testVectorScalarConcat() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorScalarConcat(Nd4jBackend backend) { val vector = Nd4j.createFromArray(new float[] {1, 2}); val scalar = Nd4j.scalar(3.0f); @@ -6232,7 +6842,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testScalarPrint_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarPrint_1(Nd4jBackend backend) { val scalar = Nd4j.scalar(3.0f); Nd4j.exec(new PrintVariable(scalar, true)); @@ -6240,7 +6852,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testValueArrayOf_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testValueArrayOf_1(Nd4jBackend backend) { val vector = Nd4j.valueArrayOf(new long[] {5}, 2f, DataType.FLOAT); val exp = Nd4j.createFromArray(new float[]{2, 2, 2, 2, 2}); @@ -6250,7 +6864,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testValueArrayOf_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testValueArrayOf_2(Nd4jBackend backend) { val scalar = Nd4j.valueArrayOf(new long[] {}, 2f); val exp = Nd4j.scalar(2f); @@ -6260,7 +6876,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testArrayCreation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArrayCreation(Nd4jBackend backend) { val vector = Nd4j.create(new float[]{1, 2, 3}, new long[] {3}, 'c'); val exp = Nd4j.createFromArray(new float[]{1, 2, 3}); @@ -6269,6 +6887,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testACosh(){ //http://www.wolframalpha.com/input/?i=acosh(x) @@ -6286,6 +6906,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCosh(){ //http://www.wolframalpha.com/input/?i=cosh(x) @@ -6303,6 +6925,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testAtanh(){ //http://www.wolframalpha.com/input/?i=atanh(x) @@ -6321,6 +6945,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLastIndex(){ INDArray in = Nd4j.create(new double[][]{ @@ -6338,7 +6964,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test() - public void testBadReduce3Call() { + public void testBadReduce3Call(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { val x = Nd4j.create(400,20); val y = Nd4j.ones(1, 20); @@ -6349,7 +6975,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testReduce3AlexBug() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduce3AlexBug(Nd4jBackend backend) { val arr = Nd4j.linspace(1,100,100, DataType.DOUBLE).reshape('f', 10, 10).dup('c'); val arr2 = Nd4j.linspace(1,100,100, DataType.DOUBLE).reshape('c', 10, 10); val out = Nd4j.getExecutioner().exec(new EuclideanDistance(arr, arr2, 1)); @@ -6359,7 +6987,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAllDistancesEdgeCase1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllDistancesEdgeCase1(Nd4jBackend backend) { val x = Nd4j.create(400, 20).assign(2.0); val y = Nd4j.ones(1, 20); val z = Transforms.allEuclideanDistances(x, y, 1); @@ -6370,7 +7000,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testConcat_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcat_1(Nd4jBackend backend) { for(char order : new char[]{'c', 'f'}) { INDArray arr1 = Nd4j.create(new double[]{1, 2}, new long[]{1, 2}, order); @@ -6384,6 +7016,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRdiv() { final INDArray a = Nd4j.create(new double[]{2.0, 2.0, 2.0, 2.0}); final INDArray b = Nd4j.create(new double[]{1.0, 2.0, 4.0, 8.0}); @@ -6403,6 +7037,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRsub() { final INDArray a = Nd4j.create(new double[]{2.0, 2.0, 2.0, 2.0}); final INDArray b = Nd4j.create(new double[]{1.0, 2.0, 4.0, 8.0}); @@ -6423,7 +7059,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testHalfStuff() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHalfStuff(Nd4jBackend backend) { if (!Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) return; @@ -6442,6 +7080,8 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testInconsistentOutput(){ INDArray in = Nd4j.rand(1, 802816); INDArray W = Nd4j.rand(802816, 1); @@ -6455,7 +7095,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void test3D_create_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test3D_create_1(Nd4jBackend backend) { val jArray = new float[2][3][4]; fillJvmArray3D(jArray); @@ -6474,7 +7116,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void test4D_create_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test4D_create_1(Nd4jBackend backend) { val jArray = new float[2][3][4][5]; fillJvmArray4D(jArray); @@ -6492,7 +7136,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBroadcast_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcast_1(Nd4jBackend backend) { val array1 = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(5, 1, 2).broadcast(5, 4, 2); val array2 = Nd4j.linspace(1, 20, 20, DataType.DOUBLE).reshape(5, 4, 1).broadcast(5, 4, 2); val exp = Nd4j.create(new double[] {2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 6.0f, 8.0f, 9.0f, 9.0f, 10.0f, 10.0f, 11.0f, 11.0f, 12.0f, 14.0f, 15.0f, 15.0f, 16.0f, 16.0f, 17.0f, 17.0f, 18.0f, 20.0f, 21.0f, 21.0f, 22.0f, 22.0f, 23.0f, 23.0f, 24.0f, 26.0f, 27.0f, 27.0f, 28.0f, 28.0f, 29.0f, 29.0f, 30.0f}).reshape(5,4,2); @@ -6504,6 +7150,8 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testAddiColumnEdge(){ INDArray arr1 = Nd4j.create(1, 5); arr1.addiColumnVector(Nd4j.ones(1)); @@ -6512,7 +7160,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testMmulViews_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmulViews_1(Nd4jBackend backend) { val arrayX = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3); val arrayA = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3); @@ -6531,7 +7181,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTile_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTile_1(Nd4jBackend backend) { val array = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); val exp = Nd4j.create(new double[] {1.000000, 2.000000, 3.000000, 1.000000, 2.000000, 3.000000, 4.000000, 5.000000, 6.000000, 4.000000, 5.000000, 6.000000, 1.000000, 2.000000, 3.000000, 1.000000, 2.000000, 3.000000, 4.000000, 5.000000, 6.000000, 4.000000, 5.000000, 6.000000}, new int[] {4, 6}); val output = Nd4j.create(4, 6); @@ -6548,7 +7200,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testRelativeError_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRelativeError_1(Nd4jBackend backend) { val arrayX = Nd4j.create(10, 10); val arrayY = Nd4j.ones(10, 10); val exp = Nd4j.ones(10, 10); @@ -6559,11 +7213,15 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testBugMeshgridOnDoubleArray() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBugMeshgridOnDoubleArray(Nd4jBackend backend) { Nd4j.meshgrid(Nd4j.create(new double[] { 1, 2, 3 }), Nd4j.create(new double[] { 4, 5, 6 })); } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMeshGrid(){ INDArray x1 = Nd4j.create(new double[]{1,2,3,4}).reshape(1, -1); @@ -6602,7 +7260,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testAccumuationWithoutAxis_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAccumuationWithoutAxis_1(Nd4jBackend backend) { val array = Nd4j.create(3, 3).assign(1.0); val result = array.sum(); @@ -6612,7 +7272,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSummaryStatsEquality_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSummaryStatsEquality_1(Nd4jBackend backend) { // log.info("Datatype: {}", Nd4j.dataType()); for(boolean biasCorrected : new boolean[]{false, true}) { @@ -6631,6 +7293,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMeanEdgeCase_C(){ INDArray arr = Nd4j.linspace(1, 30,30, DataType.DOUBLE).reshape(new int[]{3,10,1}).dup('c'); INDArray arr2 = arr.mean(2); @@ -6641,6 +7305,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMeanEdgeCase_F(){ INDArray arr = Nd4j.linspace(1, 30,30, DataType.DOUBLE).reshape(new int[]{3,10,1}).dup('f'); INDArray arr2 = arr.mean(2); @@ -6651,6 +7317,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMeanEdgeCase2_C(){ INDArray arr = Nd4j.linspace(1, 60,60, DataType.DOUBLE).reshape(new int[]{3,10,2}).dup('c'); INDArray arr2 = arr.mean(2); @@ -6664,6 +7332,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMeanEdgeCase2_F(){ INDArray arr = Nd4j.linspace(1, 60,60, DataType.DOUBLE).reshape(new int[]{3,10,2}).dup('f'); INDArray arr2 = arr.mean(2); @@ -6677,6 +7347,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLegacyDeserialization_1() throws Exception { val f = new ClassPathResource("legacy/NDArray_javacpp.bin").getFile(); @@ -6697,7 +7369,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testRndBloat16() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRndBloat16(Nd4jBackend backend) { INDArray x = Nd4j.rand(DataType.BFLOAT16 , 'c', new long[]{5}); assertTrue(x.sumNumber().floatValue() > 0); @@ -6706,6 +7380,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLegacyDeserialization_2() throws Exception { val f = new ClassPathResource("legacy/NDArray_longshape_float.bin").getFile(); @@ -6727,6 +7403,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLegacyDeserialization_3() throws Exception { val f = new ClassPathResource("legacy/NDArray_longshape_double.bin").getFile(); @@ -6747,7 +7425,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testTearPile_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTearPile_1(Nd4jBackend backend) { val source = Nd4j.rand(new int[]{10, 15}); val list = Nd4j.tear(source, 1); @@ -6762,7 +7442,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testVariance_4D_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariance_4D_1(Nd4jBackend backend) { val dtype = Nd4j.dataType(); Nd4j.setDataType(DataType.FLOAT); @@ -6778,6 +7460,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTranspose_Custom(){ INDArray arr = Nd4j.linspace(1,15, 15, DataType.DOUBLE).reshape(5,3); @@ -6795,6 +7479,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRowColumnOpsRank1(){ for( int i=0; i<6; i++ ) { @@ -6858,6 +7544,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testEmptyShapeRank0(){ Nd4j.getRandom().setSeed(12345); int[] s = new int[0]; @@ -6894,7 +7582,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testScalarView_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarView_1(Nd4jBackend backend) { val array = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); val exp = Nd4j.create(new double[]{1.0, 2.0, 5.0, 4.0, 5.0}); val scalar = array.getScalar(2); @@ -6906,7 +7596,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testScalarView_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarView_2(Nd4jBackend backend) { val array = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); val exp = Nd4j.create(new double[]{1.0, 2.0, 5.0, 4.0}).reshape(2, 2); val scalar = array.getScalar(1, 0); @@ -6918,7 +7610,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSomething_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSomething_1(Nd4jBackend backend) { val arrayX = Nd4j.create(128, 128, 'f'); val arrayY = Nd4j.create(128, 128, 'f'); val arrayZ = Nd4j.create(128, 128, 'f'); @@ -6945,7 +7639,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testIndexesIteration_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIndexesIteration_1(Nd4jBackend backend) { val arrayC = Nd4j.linspace(1, 60, 60, DataType.DOUBLE).reshape(3, 4, 5); val arrayF = arrayC.dup('f'); @@ -6962,7 +7658,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testIndexesIteration_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIndexesIteration_2(Nd4jBackend backend) { val arrayC = Nd4j.linspace(1, 60, 60, DataType.DOUBLE).reshape(3, 4, 5); val arrayF = arrayC.dup('f'); @@ -6983,28 +7681,12 @@ public class Nd4jTestsC extends BaseNd4jTest { } } - @Test - @Disabled - public void testMatmul_vs_tf() throws Exception { - // uncomment this line to initialize & propagate sgemm/dgemm pointer - //Nd4j.getBlasWrapper().level3(); - - val arrayA = NodeReader.readArray("mnist_00", "input.placeholder"); - val arrayB = NodeReader.readArray("mnist_00", "Variable.0"); - val arrayC = Nd4j.create(100, 10); - val exp = NodeReader.readArray("mnist_00", "MatMul.0"); - val badExp = Nd4j.create(100, 10); - - Mmul op = new Mmul(arrayA, arrayB, arrayC, null); - Nd4j.getExecutioner().exec(op); - - assertEquals(exp, arrayC); - assertNotEquals(badExp, arrayC); - } @Test - public void testPairwiseScalar_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPairwiseScalar_1(Nd4jBackend backend) { val exp_1 = Nd4j.create(new double[]{2.0, 3.0, 4.0}, new long[]{3}); val exp_2 = Nd4j.create(new double[]{0.0, 1.0, 2.0}, new long[]{3}); val exp_3 = Nd4j.create(new double[]{1.0, 2.0, 3.0}, new long[]{3}); @@ -7025,7 +7707,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testLTOE_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLTOE_1(Nd4jBackend backend) { val x = Nd4j.create(new double[]{1.0, 2.0, 3.0, -1.0}); val y = Nd4j.create(new double[]{2.0, 2.0, 3.0, -2.0}); @@ -7042,7 +7726,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testGTOE_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGTOE_1(Nd4jBackend backend) { val x = Nd4j.create(new double[]{1.0, 2.0, 3.0, -1.0}); val y = Nd4j.create(new double[]{2.0, 2.0, 3.0, -2.0}); @@ -7075,6 +7761,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGet(){ //https://github.com/deeplearning4j/deeplearning4j/issues/6133 INDArray m = Nd4j.linspace(0,99,100, DataType.DOUBLE).reshape('c', 10,10); @@ -7099,6 +7787,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWhere1(){ INDArray arr = Nd4j.create(new boolean[][]{{false,true,false},{false,false,true},{false,false,true}}); @@ -7112,6 +7802,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWhere2(){ INDArray arr = Nd4j.create(DataType.BOOL, 3,3,3); @@ -7130,6 +7822,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWhere3(){ INDArray arr = Nd4j.create(new boolean[][]{{false,true,false},{false,false,true},{false,false,true}}); INDArray x = Nd4j.valueArrayOf(3, 3, 1.0); @@ -7146,6 +7840,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWhereEmpty(){ INDArray inArray = Nd4j.zeros(2, 3); inArray.putScalar(0, 0, 10.0f); @@ -7171,7 +7867,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testScalarEquality_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarEquality_1(Nd4jBackend backend) { val x = Nd4j.scalar(1.0f); val e = Nd4j.scalar(3.0f); @@ -7181,6 +7879,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testStack(){ INDArray in = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4); INDArray in2 = in.add(100); @@ -7209,6 +7909,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPutSpecifiedIndex(){ long[][] ss = new long[][]{{3,4}, {3,4,5}, {3,4,5,6}}; long[][] st = new long[][]{{4,4}, {4,4,5}, {4,4,5,6}}; @@ -7240,6 +7942,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPutSpecifiedIndices2d(){ INDArray arr = Nd4j.create(3,4); @@ -7258,6 +7962,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPutSpecifiedIndices3d(){ INDArray arr = Nd4j.create(2,3,4); @@ -7278,7 +7984,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testSpecifiedIndexArraySize1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSpecifiedIndexArraySize1(Nd4jBackend backend) { long[] shape = {2, 2, 2, 2}; INDArray in = Nd4j.create(shape); INDArrayIndex[] idx1 = new INDArrayIndex[]{NDArrayIndex.all(), new SpecifiedIndex(0), NDArrayIndex.all(), NDArrayIndex.all()}; @@ -7289,6 +7997,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTransposei(){ INDArray arr = Nd4j.linspace(1,12,12).reshape('c',3,4); @@ -7300,7 +8010,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testScatterUpdateShortcut() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterUpdateShortcut(Nd4jBackend backend) { val array = Nd4j.create(DataType.FLOAT, 5, 2); val updates = Nd4j.createFromArray(new float[][] {{1,1}, {2,2}, {3, 3}}); val indices = Nd4j.createFromArray(new int[]{1, 2, 3}); @@ -7313,7 +8025,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test() - public void testScatterUpdateShortcut_f1() { + public void testScatterUpdateShortcut_f1(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { val array = Nd4j.create(DataType.FLOAT, 5, 2); val updates = Nd4j.createFromArray(new float[][] {{1,1}, {2,2}, {3, 3}}); @@ -7329,7 +8041,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testStatistics_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStatistics_1(Nd4jBackend backend) { val array = Nd4j.createFromArray(new float[] {-1.0f, 0.0f, 1.0f}); val stats = Nd4j.getExecutioner().inspectArray(array); @@ -7340,6 +8054,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testINDArrayMmulWithTranspose(){ Nd4j.getRandom().setSeed(12345); INDArray a = Nd4j.rand(2,5); @@ -7379,6 +8095,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testInvalidOrder(){ try { @@ -7432,6 +8150,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testAssignValid(){ INDArray arr1 = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); INDArray arr2 = Nd4j.create(3,4); @@ -7440,6 +8160,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testAssignInvalid(){ INDArray arr1 = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); INDArray arr2 = Nd4j.create(4,3); @@ -7452,6 +8174,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testEmptyCasting(){ for(val from : DataType.values()) { if (from == DataType.UTF8 || from == DataType.UNKNOWN || from == DataType.COMPRESSED) @@ -7478,6 +8202,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testVStackRank1(){ List list = new ArrayList<>(); list.add(Nd4j.linspace(1,3,3, DataType.DOUBLE)); @@ -7493,6 +8219,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testAxpyOpRows(){ INDArray arr = Nd4j.create(1,4).assign(2.0f); INDArray ones = Nd4j.ones(1,4).assign(3.0f); @@ -7505,12 +8233,16 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testEmptyArray() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyArray(Nd4jBackend backend) { INDArray empty = Nd4j.empty(DataType.INT); assertEquals(empty.toString(), "[]"); } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLinspaceWithStep(){ double lower = -0.9, upper = 0.9, step = 0.2; @@ -7541,6 +8273,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLinspaceWithStepForIntegers(){ long lower = -9, upper = 9, step = 2; @@ -7571,7 +8305,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testArangeWithStep() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArangeWithStep(Nd4jBackend backend) { int begin = -9, end = 9, step = 2; INDArray in = Nd4j.arange(begin, end, step); assertEquals(in.getInt(0), -9); @@ -7586,7 +8322,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testRollingMean() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRollingMean(Nd4jBackend backend) { val wsconf = WorkspaceConfiguration.builder() .initialSize(4L * (32*128*256*256 + 32*128 + 10*1024*1024)) .policyLearning(LearningPolicy.FIRST_LOOP) @@ -7620,11 +8358,15 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testZerosRank1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testZerosRank1(Nd4jBackend backend) { Nd4j.zeros(new int[] { 2 }, DataType.DOUBLE); } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testReshapeEnforce(){ INDArray arr = Nd4j.create(new long[]{2,2}, 'c'); @@ -7644,6 +8386,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRepeatSimple(){ INDArray arr = Nd4j.createFromArray(new double[][]{ @@ -7667,6 +8411,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRowsEdgeCaseView(){ INDArray arr = Nd4j.linspace(0, 9, 10, DataType.DOUBLE).reshape('f', 5, 2).dup('c'); //0,1,2... along columns @@ -7681,7 +8427,7 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test() - public void testPullRowsFailure() { + public void testPullRowsFailure(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { val idxs = new int[]{0,2,3,4}; val out = Nd4j.pullRows(Nd4j.createFromArray(0.0, 1.0, 2.0, 3.0, 4.0), 0, idxs); @@ -7690,7 +8436,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testRepeatStrided() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRepeatStrided(Nd4jBackend backend) { // Create a 2D array (shape 5x5) INDArray array = Nd4j.arange(25).reshape(5, 5); @@ -7709,7 +8457,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMeshgridDtypes() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMeshgridDtypes(Nd4jBackend backend) { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); Nd4j.meshgrid(Nd4j.create(new double[] { 1, 2, 3 }), Nd4j.create(new double[] { 4, 5, 6 })); @@ -7717,6 +8467,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGetColumnRowVector(){ INDArray arr = Nd4j.create(1,4); INDArray col = arr.getColumn(0); @@ -7726,6 +8478,8 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testEmptyArrayReuse(){ //Empty arrays are immutable - no point creating them multiple times INDArray ef1 = Nd4j.empty(DataType.FLOAT); @@ -7738,6 +8492,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMaxViewF(){ INDArray arr = Nd4j.create(DataType.DOUBLE, new long[]{8,2}, 'f').assign(999); @@ -7749,6 +8505,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMin2(){ INDArray x = Nd4j.createFromArray(new double[][]{ {-999, 0.2236, 0.7973, 0.0962}, @@ -7778,18 +8536,18 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test() - public void testPutRowValidation() { - assertThrows(IllegalArgumentException.class,() -> { - val matrix = Nd4j.create(5, 10); - val row = Nd4j.create(25); + public void testPutRowValidation(Nd4jBackend backend) { + assertThrows(IllegalArgumentException.class,() -> { + val matrix = Nd4j.create(5, 10); + val row = Nd4j.create(25); - matrix.putRow(1, row); - }); + matrix.putRow(1, row); + }); } @Test() - public void testPutColumnValidation() { + public void testPutColumnValidation(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { val matrix = Nd4j.create(5, 10); val column = Nd4j.create(25); @@ -7800,6 +8558,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCreateF(){ char origOrder = Nd4j.order(); try { @@ -7833,6 +8593,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testReduceKeepDimsShape(){ INDArray arr = Nd4j.create(3,4); INDArray out = arr.sum(true, 1); @@ -7843,6 +8605,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSliceRow(){ double[] data = new double[]{15.0, 16.0}; INDArray vector = Nd4j.createFromArray(data).reshape(1,2); @@ -7854,6 +8618,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSliceMatrix(){ INDArray arr = Nd4j.arange(4).reshape(2,2); // System.out.println(arr.slice(0)); @@ -7864,6 +8630,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testScalarEq(){ INDArray scalarRank2 = Nd4j.scalar(10.0).reshape(1,1); INDArray scalarRank1 = Nd4j.scalar(10.0).reshape(1); @@ -7879,7 +8647,9 @@ public class Nd4jTestsC extends BaseNd4jTest { //@Disabled // https://github.com/eclipse/deeplearning4j/issues/7632 @Test - public void testGetWhereINDArray() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetWhereINDArray(Nd4jBackend backend) { INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 }); INDArray comp = Nd4j.create(new double[]{2, -3, 1, 1, -2, 1 }); INDArray expected = Nd4j.create(new double[] { 4, 8, 5 }); @@ -7889,7 +8659,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testGetWhereNumber() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetWhereNumber(Nd4jBackend backend) { INDArray input = Nd4j.create(new double[] { 1, -3, 4, 8, -2, 5 }); INDArray expected = Nd4j.create(new double[] { 8, 5 }); INDArray actual = input.getWhere(4, Conditions.greaterThan(1)); @@ -7898,6 +8670,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testType1(@TempDir Path testDir) throws IOException { for (int i = 0; i < 10; ++i) { INDArray in1 = Nd4j.rand(DataType.DOUBLE, new int[]{100, 100}); @@ -7919,6 +8693,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testOnes(){ INDArray arr = Nd4j.ones(); INDArray arr2 = Nd4j.ones(DataType.LONG); @@ -7929,6 +8705,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testZeros(){ INDArray arr = Nd4j.zeros(); INDArray arr2 = Nd4j.zeros(DataType.LONG); @@ -7939,6 +8717,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testType2(@TempDir Path testDir) throws IOException { for (int i = 0; i < 10; ++i) { INDArray in1 = Nd4j.ones(DataType.UINT16); @@ -7994,6 +8774,8 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testToXMatrix(){ List shapes = Arrays.asList(new long[]{3, 4}, new long[]{3, 1}, new long[]{1,3}); @@ -8023,6 +8805,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testToXVector(){ List shapes = Arrays.asList(new long[]{3}, new long[]{3, 1}, new long[]{1,3}); @@ -8053,6 +8837,8 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSumEdgeCase(){ INDArray row = Nd4j.create(1,3); INDArray sum = row.sum(0); @@ -8064,6 +8850,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMedianEdgeCase(){ INDArray rowVec = Nd4j.rand(DataType.FLOAT, 1, 10); INDArray median = rowVec.median(0); @@ -8083,7 +8871,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void mmulToScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void mmulToScalar(Nd4jBackend backend) { final INDArray arr1 = Nd4j.create(new float[] {1,2,3}).reshape(1,3); final INDArray arr2 = arr1.reshape(3,1); assertEquals( DataType.FLOAT, arr1.mmul(arr2).dataType(),"Incorrect type!"); @@ -8091,7 +8881,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testCreateDtypes() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCreateDtypes(Nd4jBackend backend) { int[] sliceShape = new int[] {9}; float[] arrays = new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}; double [] arrays_double = new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; @@ -8105,6 +8897,8 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCreateShapeValidation(){ try { Nd4j.create(new double[]{1, 2, 3}, new int[]{1, 1}); @@ -8156,6 +8950,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBatchToSpace(){ INDArray out = Nd4j.create(DataType.FLOAT, 2, 4, 5); @@ -8177,6 +8973,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testToFromByteArray() throws IOException { // simple test to get rid of toByteArray and fromByteArray compiler warnings. INDArray x = Nd4j.arange(10); @@ -8194,7 +8992,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testVStackHStack1d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVStackHStack1d(Nd4jBackend backend) { INDArray rowVector1 = Nd4j.create(new double[]{1,2,3}); INDArray rowVector2 = Nd4j.create(new double[]{4,5,6}); @@ -8207,7 +9007,9 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test - public void testReduceAll_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduceAll_1(Nd4jBackend backend) { val x = Nd4j.empty(DataType.FLOAT); val e = Nd4j.scalar(true); val z = Nd4j.exec(new All(x)); @@ -8216,7 +9018,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReduceAll_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduceAll_2(Nd4jBackend backend) { val x = Nd4j.ones(DataType.FLOAT, 0); val e = Nd4j.scalar(true); val z = Nd4j.exec(new All(x)); @@ -8225,7 +9029,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testReduceAll_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduceAll_3(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 0); assertEquals(1, x.rank()); @@ -8236,6 +9042,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testScalarEqualsNoResult(){ INDArray out = Nd4j.exec(new ScalarEquals(Nd4j.createFromArray(-2, -1, 0, 1, 2), null, 0)); INDArray exp = Nd4j.createFromArray(false, false, true, false, false); @@ -8243,6 +9051,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPutOverwrite(){ INDArray arr = Nd4j.create(DataType.DOUBLE, 10); arr.putScalar(0, 10); @@ -8254,6 +9064,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testEmptyReshapingMinus1(){ INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0); INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2); @@ -8268,7 +9080,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testConv2DWeightsFormat1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConv2DWeightsFormat1(Nd4jBackend backend) { int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, pW = 0, dH = 1, dW = 1; int oH=2,oW=2; // Weights format tip : @@ -8300,7 +9114,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testConv2DWeightsFormat2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConv2DWeightsFormat2(Nd4jBackend backend) { int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; int oH=4,oW=3; WeightsFormat format = WeightsFormat.OYXI; @@ -8330,7 +9146,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMatmulMethod_8() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatmulMethod_8(Nd4jBackend backend) { val x = Nd4j.create(DataType.INT8, 3, 5).assign(1); val y = Nd4j.create(DataType.INT8, 5, 3).assign(1); val e = Nd4j.create(DataType.INT8, 3, 3).assign(5); @@ -8340,7 +9158,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMatmulMethod_7() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatmulMethod_7(Nd4jBackend backend) { val x = Nd4j.create(DataType.INT16, 3, 5).assign(1); val y = Nd4j.create(DataType.INT16, 5, 3).assign(1); val e = Nd4j.create(DataType.INT16, 3, 3).assign(5); @@ -8350,7 +9170,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMatmulMethod_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatmulMethod_1(Nd4jBackend backend) { val x = Nd4j.create(DataType.INT32, 3, 5).assign(1); val y = Nd4j.create(DataType.INT32, 5, 3).assign(1); val e = Nd4j.create(DataType.INT32, 3, 3).assign(5); @@ -8360,7 +9182,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMatmulMethod_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatmulMethod_2(Nd4jBackend backend) { val x = Nd4j.create(DataType.INT64, 3, 5).assign(1); val y = Nd4j.create(DataType.INT64, 5, 3).assign(1); val e = Nd4j.create(DataType.INT64, 3, 3).assign(5); @@ -8370,7 +9194,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMatmulMethod_6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatmulMethod_6(Nd4jBackend backend) { val x = Nd4j.create(DataType.UINT8, 3, 5).assign(1); val y = Nd4j.create(DataType.UINT8, 5, 3).assign(1); val e = Nd4j.create(DataType.UINT8, 3, 3).assign(5); @@ -8380,7 +9206,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMatmulMethod_5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatmulMethod_5(Nd4jBackend backend) { val x = Nd4j.create(DataType.UINT16, 3, 5).assign(1); val y = Nd4j.create(DataType.UINT16, 5, 3).assign(1); val e = Nd4j.create(DataType.UINT16, 3, 3).assign(5); @@ -8390,7 +9218,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMatmulMethod_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatmulMethod_3(Nd4jBackend backend) { val x = Nd4j.create(DataType.UINT32, 3, 5).assign(1); val y = Nd4j.create(DataType.UINT32, 5, 3).assign(1); val e = Nd4j.create(DataType.UINT32, 3, 3).assign(5); @@ -8400,7 +9230,9 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test - public void testMatmulMethod_4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatmulMethod_4(Nd4jBackend backend) { val x = Nd4j.create(DataType.UINT64, 3, 5).assign(1); val y = Nd4j.create(DataType.UINT64, 5, 3).assign(1); val e = Nd4j.create(DataType.UINT64, 3, 3).assign(5); @@ -8410,6 +9242,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCreateBufferFromByteBuffer(){ for(DataType dt : DataType.values()){ @@ -8437,6 +9271,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCreateBufferFromByteBufferViews(){ for(DataType dt : DataType.values()){ @@ -8462,6 +9298,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTypeCastingToString(){ for(DataType dt : DataType.values()) { @@ -8480,6 +9318,8 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testShape0Casts(){ for(DataType dt : DataType.values()){ if(!dt.isNumerical()) @@ -8499,6 +9339,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSmallSort(){ INDArray arr = Nd4j.createFromArray(0.5, 0.4, 0.1, 0.2); INDArray expected = Nd4j.createFromArray(0.1, 0.2, 0.4, 0.5); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java index 52bb00738..f6ebb4b57 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonC.java @@ -23,8 +23,9 @@ package org.nd4j.linalg; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.ndarray.INDArray; @@ -42,18 +43,14 @@ import static org.junit.jupiter.api.Assertions.assertTrue; -@RunWith(Parameterized.class) -public class Nd4jTestsComparisonC extends BaseNd4jTest { + +public class Nd4jTestsComparisonC extends BaseNd4jTestWithBackends { private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonC.class); public static final int SEED = 123; - DataType initialType; + DataType initialType = Nd4j.dataType(); - public Nd4jTestsComparisonC(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } @BeforeEach @@ -73,7 +70,9 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest { @Test - public void testGemmWithOpsCommonsMath() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemmWithOpsCommonsMath(Nd4jBackend backend) { List> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE); List> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE); @@ -140,13 +139,13 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest { private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair first, - Pair second) { + Pair second) { return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")"; } private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha, - double beta, Pair first, Pair second) { + double beta, Pair first, Pair second) { return i + "," + j + " - gemm(tA=" + transposeA + ",tB=" + transposeB + ",alpha=" + alpha + ",beta=" + beta - + "). A=" + first.getSecond() + ", B=" + second.getSecond(); + + "). A=" + first.getSecond() + ", B=" + second.getSecond(); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java index a45cebc75..0be72945b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java @@ -25,8 +25,9 @@ import org.apache.commons.math3.linear.RealMatrix; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.ndarray.INDArray; @@ -43,18 +44,14 @@ import java.util.Random; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class Nd4jTestsComparisonFortran extends BaseNd4jTest { + +public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends { private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonFortran.class); public static final int SEED = 123; - DataType initialType; + DataType initialType = Nd4j.dataType(); - public Nd4jTestsComparisonFortran(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } @BeforeEach @@ -75,7 +72,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { } @Test - public void testCrash() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCrash(Nd4jBackend backend) { INDArray array3d = Nd4j.ones(1, 10, 10); Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 0); Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 1); @@ -85,7 +84,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { } @Test - public void testMmulWithOpsCommonsMath() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmulWithOpsCommonsMath(Nd4jBackend backend) { List> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE); @@ -100,7 +101,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { } @Test - public void testGemmWithOpsCommonsMath() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemmWithOpsCommonsMath(Nd4jBackend backend) { List> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE); List> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE); @@ -156,7 +159,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { } @Test - public void testGemvApacheCommons() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemvApacheCommons(Nd4jBackend backend) { int[] rowsArr = new int[] {4, 4, 4, 8, 8, 8}; int[] colsArr = new int[] {2, 1, 10, 2, 1, 10}; @@ -197,7 +202,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { assertArrayEquals(new long[] {rows, 1}, gemv.shape()); assertArrayEquals(new int[] {rows, 1}, - new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()}); + new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()}); //Check entries: for (int r = 0; r < rows; r++) { @@ -211,7 +216,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { } @Test - public void testAddSubtractWithOpsCommonsMath() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAddSubtractWithOpsCommonsMath(Nd4jBackend backend) { List> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); for (int i = 0; i < first.size(); i++) { @@ -229,7 +236,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { } @Test - public void testMulDivOnCheckUtilMatrices() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMulDivOnCheckUtilMatrices(Nd4jBackend backend) { List> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); List> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE); for (int i = 0; i < first.size(); i++) { @@ -245,13 +254,13 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest { } private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair first, - Pair second) { + Pair second) { return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")"; } private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha, - double beta, Pair first, Pair second) { + double beta, Pair first, Pair second) { return i + "," + j + " - gemm(tA=" + transposeA + ",tB= " + transposeB + ",alpha=" + alpha + ",beta= " + beta - + "). A=" + first.getSecond() + ", B=" + second.getSecond(); + + "). A=" + first.getSecond() + ", B=" + second.getSecond(); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsF.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsF.java index 20c783031..8837e89a2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsF.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsF.java @@ -23,8 +23,9 @@ package org.nd4j.linalg; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -36,18 +37,15 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j -@RunWith(Parameterized.class) -public class Nd4jTestsF extends BaseNd4jTest { - DataType initialType; +public class Nd4jTestsF extends BaseNd4jTestWithBackends { - public Nd4jTestsF(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } + DataType initialType = Nd4j.dataType(); @Test - public void testConcat3D_Vstack_F() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcat3D_Vstack_F(Nd4jBackend backend) { //Nd4j.getExecutioner().enableVerboseMode(true); //Nd4j.getExecutioner().enableDebugMode(true); @@ -79,7 +77,9 @@ public class Nd4jTestsF extends BaseNd4jTest { @Test - public void testSlice_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSlice_1(Nd4jBackend backend) { val arr = Nd4j.linspace(1,4, 4, DataType.DOUBLE).reshape(2, 2, 1); val exp0 = Nd4j.create(new double[]{1, 3}, new int[] {2, 1}); val exp1 = Nd4j.create(new double[]{2, 4}, new int[] {2, 1}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java index e31f9fbf8..5e7813b8d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ShufflesTests.java @@ -22,8 +22,9 @@ package org.nd4j.linalg; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -34,15 +35,13 @@ import java.util.*; import static junit.framework.TestCase.assertTrue; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class ShufflesTests extends BaseNd4jTest { - public ShufflesTests(Nd4jBackend backend) { - super(backend); - } +public class ShufflesTests extends BaseNd4jTestWithBackends { @Test - public void testSimpleShuffle1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSimpleShuffle1(Nd4jBackend backend) { INDArray array = Nd4j.zeros(10, 10); for (int x = 0; x < 10; x++) { array.getRow(x).assign(x); @@ -64,7 +63,9 @@ public class ShufflesTests extends BaseNd4jTest { } @Test - public void testSimpleShuffle2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSimpleShuffle2(Nd4jBackend backend) { INDArray array = Nd4j.zeros(10, 10); for (int x = 0; x < 10; x++) { array.getColumn(x).assign(x); @@ -79,7 +80,9 @@ public class ShufflesTests extends BaseNd4jTest { } @Test - public void testSimpleShuffle3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSimpleShuffle3(Nd4jBackend backend) { INDArray array = Nd4j.zeros(11, 10); for (int x = 0; x < 11; x++) { array.getRow(x).assign(x); @@ -95,7 +98,9 @@ public class ShufflesTests extends BaseNd4jTest { } @Test - public void testSymmetricShuffle1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSymmetricShuffle1(Nd4jBackend backend) { INDArray features = Nd4j.zeros(10, 10); INDArray labels = Nd4j.zeros(10, 3); for (int x = 0; x < 10; x++) { @@ -133,7 +138,9 @@ public class ShufflesTests extends BaseNd4jTest { } @Test - public void testSymmetricShuffle2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSymmetricShuffle2(Nd4jBackend backend) { INDArray features = Nd4j.zeros(10, 10, 20); INDArray labels = Nd4j.zeros(10, 10, 3); @@ -171,7 +178,9 @@ public class ShufflesTests extends BaseNd4jTest { } @Test - public void testSymmetricShuffle3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSymmetricShuffle3(Nd4jBackend backend) { INDArray features = Nd4j.zeros(10, 10, 20); INDArray featuresMask = Nd4j.zeros(10, 20); INDArray labels = Nd4j.zeros(10, 10, 3); @@ -236,7 +245,9 @@ public class ShufflesTests extends BaseNd4jTest { * @throws Exception */ @Test - public void testHalfVectors1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHalfVectors1(Nd4jBackend backend) { int[] array1 = ArrayUtil.buildHalfVector(new Random(12), 20); int[] array2 = ArrayUtil.buildHalfVector(new Random(75), 20); @@ -257,7 +268,9 @@ public class ShufflesTests extends BaseNd4jTest { } @Test - public void testInterleavedVector1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInterleavedVector1(Nd4jBackend backend) { int[] array1 = ArrayUtil.buildInterleavedVector(new Random(12), 20); int[] array2 = ArrayUtil.buildInterleavedVector(new Random(75), 20); @@ -278,7 +291,9 @@ public class ShufflesTests extends BaseNd4jTest { } @Test - public void testInterleavedVector3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInterleavedVector3(Nd4jBackend backend) { for (int e = 0; e < 1000; e++) { int length = e + 256; //RandomUtils.nextInt(121, 2073); int[] array1 = ArrayUtil.buildInterleavedVector(new Random(System.currentTimeMillis()), length); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/TestEigen.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/TestEigen.java index ef0ac7afe..f1b1a6b36 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/TestEigen.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/TestEigen.java @@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.eigen.Eigen; @@ -35,16 +36,11 @@ import org.nd4j.common.util.ArrayUtil; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) + @Slf4j -public class TestEigen extends BaseNd4jTest { +public class TestEigen extends BaseNd4jTestWithBackends { - protected DataType initialType; - - public TestEigen(Nd4jBackend backend) { - super(backend); - initialType = Nd4j.dataType(); - } + protected DataType initialType = Nd4j.dataType(); @BeforeEach public void before() { @@ -59,7 +55,9 @@ public class TestEigen extends BaseNd4jTest { // test of functions added by Luke Czapla // Compares solution of A x = L x to solution to A x = L B x when it is simple @Test - public void test2Syev() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test2Syev(Nd4jBackend backend) { for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { Nd4j.setDefaultDataTypes(dt, dt); @@ -78,7 +76,9 @@ public class TestEigen extends BaseNd4jTest { } @Test - public void testSyev() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSyev(Nd4jBackend backend) { for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { //log.info("Datatype: {}", dt); Nd4j.setDefaultDataTypes(dt, dt); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java index cbd99c8cb..747ea39ab 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java @@ -24,23 +24,23 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.common.util.ArrayUtil; -@RunWith(Parameterized.class) + @Slf4j -public class ToStringTest extends BaseNd4jTest { - public ToStringTest(Nd4jBackend backend) { - super(backend); - } +public class ToStringTest extends BaseNd4jTestWithBackends { @Test - public void testToString() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToString(Nd4jBackend backend) throws Exception { assertEquals("[ 1, 2, 3]", Nd4j.createFromArray(1, 2, 3).toString()); @@ -58,6 +58,8 @@ public class ToStringTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testToStringScalars(){ DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32}; String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"}; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java index 97a8270d4..4b0455305 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.activations; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.activations.impl.ActivationCube; import org.nd4j.linalg.activations.impl.ActivationELU; import org.nd4j.linalg.activations.impl.ActivationGELU; @@ -55,12 +56,9 @@ import java.util.List; import static junit.framework.TestCase.assertTrue; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) -public class TestActivation extends BaseNd4jTest { - public TestActivation(Nd4jBackend backend) { - super(backend); - } +public class TestActivation extends BaseNd4jTestWithBackends { + @Override public char ordering() { @@ -79,7 +77,9 @@ public class TestActivation extends BaseNd4jTest { } @Test - public void testRelu(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRelu(Nd4jBackend backend){ Double[] max = {null, 6.0, 2.5, 5.0}; Double[] threshold = {0.0, 0.0, 0.75, 0.2}; @@ -131,30 +131,32 @@ public class TestActivation extends BaseNd4jTest { } @Test - public void testJson() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJson(Nd4jBackend backend) throws Exception { IActivation[] activations = new IActivation[] {new ActivationCube(), new ActivationELU(0.25), - new ActivationHardSigmoid(), new ActivationHardTanH(), new ActivationIdentity(), - new ActivationLReLU(0.25), new ActivationRationalTanh(), new ActivationReLU(), - new ActivationRReLU(0.25, 0.5), new ActivationSigmoid(), new ActivationSoftmax(), - new ActivationSoftPlus(), new ActivationSoftSign(), new ActivationTanH(), new ActivationGELU(), new ActivationGELU(true)}; + new ActivationHardSigmoid(), new ActivationHardTanH(), new ActivationIdentity(), + new ActivationLReLU(0.25), new ActivationRationalTanh(), new ActivationReLU(), + new ActivationRReLU(0.25, 0.5), new ActivationSigmoid(), new ActivationSoftmax(), + new ActivationSoftPlus(), new ActivationSoftSign(), new ActivationTanH(), new ActivationGELU(), new ActivationGELU(true)}; String[][] expectedFields = new String[][] {{"@class"}, //Cube - {"@class", "alpha"}, //ELU - {"@class"}, //Hard sigmoid - {"@class"}, //Hard TanH - {"@class"}, //Identity - {"@class", "alpha"}, //Leaky Relu - {"@class"}, //rational tanh - {"@class", "max", "negativeSlope", "threshold"}, //relu - {"@class", "l", "u"}, //rrelu - {"@class"}, //sigmoid - {"@class"}, //Softmax - {"@class"}, //Softplus - {"@class"}, //Softsign - {"@class"}, //Tanh - {"@class", "precise"}, //GELU - {"@class", "precise"} //GELU precise + {"@class", "alpha"}, //ELU + {"@class"}, //Hard sigmoid + {"@class"}, //Hard TanH + {"@class"}, //Identity + {"@class", "alpha"}, //Leaky Relu + {"@class"}, //rational tanh + {"@class", "max", "negativeSlope", "threshold"}, //relu + {"@class", "l", "u"}, //rrelu + {"@class"}, //sigmoid + {"@class"}, //Softmax + {"@class"}, //Softplus + {"@class"}, //Softsign + {"@class"}, //Tanh + {"@class", "precise"}, //GELU + {"@class", "precise"} //GELU precise }; @@ -172,7 +174,7 @@ public class TestActivation extends BaseNd4jTest { String[] expFields = expectedFields[i]; String msg = activations[i].toString() + "\tExpected fields: " + Arrays.toString(expFields) - + "\tActual fields: " + actualFieldsByName; + + "\tActual fields: " + actualFieldsByName; assertEquals(expFields.length, actualFieldsByName.size(),msg); for (String s : expFields) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestBackend.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestBackend.java index a2229aa0f..64e5d4924 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestBackend.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestBackend.java @@ -20,21 +20,20 @@ package org.nd4j.linalg.api; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; -import org.nd4j.linalg.factory.Environment; -import org.nd4j.linalg.factory.Nd4j; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertFalse; -public class TestBackend extends BaseNd4jTest { +public class TestBackend extends BaseNd4jTestWithBackends { - public TestBackend(Nd4jBackend backend) { - super(backend); - } - @Test - public void TestBuildInfo(){ + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBuildInfo(Nd4jBackend backend){ System.out.println("Backend build info: " + backend.buildInfo()); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestEnvironment.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestEnvironment.java index 8ee444adf..1eb61c4f1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestEnvironment.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestEnvironment.java @@ -20,26 +20,27 @@ package org.nd4j.linalg.api; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.factory.Environment; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertFalse; -public class TestEnvironment extends BaseNd4jTest { +public class TestEnvironment extends BaseNd4jTestWithBackends { - public TestEnvironment(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { return 'c'; } - @Test - public void testEnvironment(){ + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEnvironment(Nd4jBackend backend){ Environment e = Nd4j.getEnvironment(); System.out.println("BLAS version: " + e.blasMajorVersion() + "." + e.blasMinorVersion() + "." + e.blasPatchVersion()); System.out.println("CPU: " + e.isCPU()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreation.java index 1a3ce86f6..4eb25d221 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreation.java @@ -26,7 +26,9 @@ import org.bytedeco.javacpp.FloatPointer; import org.bytedeco.javacpp.Pointer; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -40,16 +42,12 @@ import java.util.Map; import static org.junit.jupiter.api.Assertions.*; @Slf4j -public class TestNDArrayCreation extends BaseNd4jTest { - - - public TestNDArrayCreation(Nd4jBackend backend) { - super(backend); - } +public class TestNDArrayCreation extends BaseNd4jTestWithBackends { @Test - @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") - public void testBufferCreation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBufferCreation(Nd4jBackend backend) { DataBuffer dataBuffer = Nd4j.createBuffer(new float[] {1, 2}); Pointer pointer = dataBuffer.pointer(); FloatPointer floatPointer = new FloatPointer(pointer); @@ -69,6 +67,8 @@ public class TestNDArrayCreation extends BaseNd4jTest { @Test @Disabled + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCreateNpy() throws Exception { INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile()); assertEquals(2, arrCreate.size(0)); @@ -82,7 +82,9 @@ public class TestNDArrayCreation extends BaseNd4jTest { @Test @Disabled - public void testCreateNpz() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCreateNpz(Nd4jBackend backend) throws Exception { Map map = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile()); assertEquals(true, map.containsKey("x")); assertEquals(true, map.containsKey("y")); @@ -100,8 +102,7 @@ public class TestNDArrayCreation extends BaseNd4jTest { } @Test - @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") - public void testCreateNpy3() throws Exception { + public void testCreateNpy3(Nd4jBackend backend) throws Exception { INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile()); assertEquals(8, arrCreate.length()); assertEquals(3, arrCreate.rank()); @@ -113,7 +114,7 @@ public class TestNDArrayCreation extends BaseNd4jTest { @Test @Disabled // this is endless test - public void testEndlessAllocation() { + public void testEndlessAllocation(Nd4jBackend backend) { Nd4j.getEnvironment().setMaxSpecialMemory(1); while (true) { val arr = Nd4j.createUninitialized(DataType.FLOAT, 100000000); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java index 9d6dc2988..4f7823622 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreationUtil.java @@ -21,24 +21,23 @@ package org.nd4j.linalg.api; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; -import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.common.primitives.Pair; import org.nd4j.common.util.ArrayUtil; import static org.junit.jupiter.api.Assertions.assertArrayEquals; -public class TestNDArrayCreationUtil extends BaseNd4jTest { +public class TestNDArrayCreationUtil extends BaseNd4jTestWithBackends { - public TestNDArrayCreationUtil(Nd4jBackend backend) { - super(backend); - } - @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testShapes() { long[] shape2d = {2, 3}; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java index 3e8990d63..836a3d5eb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java @@ -21,20 +21,21 @@ package org.nd4j.linalg.api; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -public class TestNamespaces extends BaseNd4jTest { +public class TestNamespaces extends BaseNd4jTestWithBackends { - public TestNamespaces(Nd4jBackend backend) { - super(backend); - } @Test - public void testBitwiseSimple(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBitwiseSimple(Nd4jBackend backend){ INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT); INDArray y = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT); @@ -50,7 +51,9 @@ public class TestNamespaces extends BaseNd4jTest { } @Test - public void testMathSimple(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMathSimple(Nd4jBackend backend) { INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1); INDArray abs = Nd4j.math.abs(x); // System.out.println(x); @@ -65,7 +68,9 @@ public class TestNamespaces extends BaseNd4jTest { } @Test - public void testRandomSimple(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRandomSimple(Nd4jBackend backend){ INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10); // System.out.println(normal); INDArray uniform = Nd4j.random.uniform(0, 1, DataType.FLOAT, 10); @@ -73,7 +78,9 @@ public class TestNamespaces extends BaseNd4jTest { } @Test - public void testNeuralNetworkSimple(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNeuralNetworkSimple(Nd4jBackend backend){ INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10)); // System.out.println(out); INDArray out2 = Nd4j.nn.softmax(Nd4j.random.normal(0, 1, DataType.FLOAT, 4, 5), 1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java index 6d34fec58..bb569f928 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/LapackTest.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.api.blas; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -31,15 +32,14 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) -public class LapackTest extends BaseNd4jTest { - public LapackTest(Nd4jBackend backend) { - super(backend); - } + +public class LapackTest extends BaseNd4jTestWithBackends { @Test - public void testQRSquare() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testQRSquare(Nd4jBackend backend) { INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9}); A = A.reshape('c', 3, 3); INDArray O = Nd4j.create(A.dataType(), A.shape()); @@ -57,7 +57,9 @@ public class LapackTest extends BaseNd4jTest { } @Test - public void testQRRect() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testQRRect(Nd4jBackend backend) { INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); A = A.reshape('f', 4, 3); INDArray O = Nd4j.create(A.dataType(), A.shape()); @@ -75,7 +77,9 @@ public class LapackTest extends BaseNd4jTest { } @Test - public void testCholeskyL() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCholeskyL(Nd4jBackend backend) { INDArray A = Nd4j.create(new double[] {2, -1, 1, -1, 2, -1, 1, -1, 2,}); A = A.reshape('c', 3, 3); INDArray O = Nd4j.create(A.dataType(), A.shape()); @@ -92,7 +96,9 @@ public class LapackTest extends BaseNd4jTest { } @Test - public void testCholeskyU() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCholeskyU(Nd4jBackend backend) { INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,}); A = A.reshape('f', 3, 3); INDArray O = Nd4j.create(A.dataType(), A.shape()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level1Test.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level1Test.java index 466af9744..b9ed7c336 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level1Test.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level1Test.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.api.blas; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -35,14 +36,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class Level1Test extends BaseNd4jTest { - public Level1Test(Nd4jBackend backend) { - super(backend); - } + +public class Level1Test extends BaseNd4jTestWithBackends { @Test - public void testDot() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDot(Nd4jBackend backend) { INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4}); assertEquals(30, Nd4j.getBlasWrapper().dot(vec1, vec2), 1e-1); @@ -55,7 +55,9 @@ public class Level1Test extends BaseNd4jTest { } @Test - public void testAxpy() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAxpy(Nd4jBackend backend) { INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray row = matrix.getRow(1); Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row); @@ -64,7 +66,9 @@ public class Level1Test extends BaseNd4jTest { } @Test - public void testAxpy2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAxpy2(Nd4jBackend backend) { val rowX = Nd4j.create(new double[]{1, 2, 3, 4}); val rowY = Nd4j.create(new double[]{1, 2, 3, 4}); val exp = Nd4j.create(new double[]{3, 6, 9, 12}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level2Test.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level2Test.java index 3cab5d94a..9c22b88a9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level2Test.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level2Test.java @@ -21,23 +21,23 @@ package org.nd4j.linalg.api.blas; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) -public class Level2Test extends BaseNd4jTest { - public Level2Test(Nd4jBackend backend) { - super(backend); - } + +public class Level2Test extends BaseNd4jTestWithBackends { @Test - public void testGemv1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemv1(Nd4jBackend backend) { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); @@ -51,7 +51,9 @@ public class Level2Test extends BaseNd4jTest { } @Test - public void testGemv2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemv2(Nd4jBackend backend) { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1); @@ -65,7 +67,9 @@ public class Level2Test extends BaseNd4jTest { } @Test - public void testGemv3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemv3(Nd4jBackend backend) { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1); @@ -79,7 +83,9 @@ public class Level2Test extends BaseNd4jTest { } @Test - public void testGemv4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemv4(Nd4jBackend backend) { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); @@ -93,7 +99,9 @@ public class Level2Test extends BaseNd4jTest { } @Test - public void testGemv5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemv5(Nd4jBackend backend) { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); @@ -109,7 +117,9 @@ public class Level2Test extends BaseNd4jTest { } @Test - public void testGemv6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemv6(Nd4jBackend backend) { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); @@ -125,7 +135,9 @@ public class Level2Test extends BaseNd4jTest { } @Test - public void testGemv7() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemv7(Nd4jBackend backend) { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java index c26b3e9fb..80d9b0896 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level3Test.java @@ -21,23 +21,23 @@ package org.nd4j.linalg.api.blas; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) -public class Level3Test extends BaseNd4jTest { - public Level3Test(Nd4jBackend backend) { - super(backend); - } + +public class Level3Test extends BaseNd4jTestWithBackends { @Test - public void testGemm1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemm1(Nd4jBackend backend) { INDArray array1 = Nd4j.linspace(1, 100, 100).reshape(1, 100); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1); @@ -47,7 +47,9 @@ public class Level3Test extends BaseNd4jTest { } @Test - public void testGemm2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemm2(Nd4jBackend backend) { INDArray array1 = Nd4j.linspace(1, 100, 100).reshape('f', 1, 100); INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1); @@ -57,7 +59,9 @@ public class Level3Test extends BaseNd4jTest { } @Test - public void testGemm3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemm3(Nd4jBackend backend) { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10); @@ -75,7 +79,9 @@ public class Level3Test extends BaseNd4jTest { } @Test - public void testGemm4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemm4(Nd4jBackend backend) { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10); @@ -92,7 +98,9 @@ public class Level3Test extends BaseNd4jTest { } @Test - public void testGemm5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemm5(Nd4jBackend backend) { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10); @@ -106,7 +114,9 @@ public class Level3Test extends BaseNd4jTest { } @Test - public void testGemm6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemm6(Nd4jBackend backend) { INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100); INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/params/ParamsTestsF.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/params/ParamsTestsF.java index 24c8a8ea8..605d318fe 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/params/ParamsTestsF.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/params/ParamsTestsF.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.api.blas.params; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -33,16 +34,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class ParamsTestsF extends BaseNd4jTest { - - public ParamsTestsF(Nd4jBackend backend) { - super(backend); - } +public class ParamsTestsF extends BaseNd4jTestWithBackends { @Test - public void testGemm() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemm (Nd4jBackend backend) { INDArray a = Nd4j.create(2, 2); INDArray b = Nd4j.create(2, 3); INDArray c = Nd4j.create(2, 3); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java index 30f426216..442de77db 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataBufferTests.java @@ -25,9 +25,10 @@ import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.indexer.*; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; @@ -45,16 +46,15 @@ import java.nio.ByteOrder; import static org.junit.jupiter.api.Assertions.*; @Slf4j -@RunWith(Parameterized.class) -public class DataBufferTests extends BaseNd4jTest { - public DataBufferTests(Nd4jBackend backend) { - super(backend); - } +public class DataBufferTests extends BaseNd4jTestWithBackends { + @Test @Disabled("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657") - public void testNoArgCreateBufferFromArray() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNoArgCreateBufferFromArray(Nd4jBackend backend) { //Tests here: //1. Create from JVM array @@ -280,7 +280,9 @@ public class DataBufferTests extends BaseNd4jTest { @Test - public void testCreateTypedBuffer() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCreateTypedBuffer(Nd4jBackend backend) { WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); @@ -350,7 +352,9 @@ public class DataBufferTests extends BaseNd4jTest { } @Test - public void testAsBytes() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAsBytes(Nd4jBackend backend) { INDArray orig = Nd4j.linspace(DataType.INT, 0, 10, 1); for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16, @@ -404,7 +408,9 @@ public class DataBufferTests extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testEnsureLocation(){ //https://github.com/eclipse/deeplearning4j/issues/8783 Nd4j.create(1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java index 1719ce084..1668deda3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.api.buffer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -33,13 +34,10 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertThrows; -@RunWith(Parameterized.class) -public class DataTypeValidationTests extends BaseNd4jTest { - DataType initialType; - public DataTypeValidationTests(Nd4jBackend backend) { - super(backend); - } +public class DataTypeValidationTests extends BaseNd4jTestWithBackends { + DataType initialType = Nd4j.dataType(); + @BeforeEach public void setUp() { @@ -48,7 +46,7 @@ public class DataTypeValidationTests extends BaseNd4jTest { } @AfterEach - public void shutUp() { + public void reset() { Nd4j.setDataType(initialType); } @@ -73,7 +71,9 @@ public class DataTypeValidationTests extends BaseNd4jTest { * Testing level1 blas */ @Test() - public void testBlasValidation1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasValidation1(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { INDArray x = Nd4j.create(10); @@ -90,7 +90,9 @@ public class DataTypeValidationTests extends BaseNd4jTest { * Testing level2 blas */ @Test() - public void testBlasValidation2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasValidation2(Nd4jBackend backend) { assertThrows(RuntimeException.class,() -> { INDArray a = Nd4j.create(100, 10); INDArray x = Nd4j.create(100); @@ -108,7 +110,9 @@ public class DataTypeValidationTests extends BaseNd4jTest { * Testing level3 blas */ @Test() - public void testBlasValidation3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasValidation3(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { INDArray x = Nd4j.create(100, 100); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java index ccaa1f4d1..58ac518f8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DoubleDataBufferTest.java @@ -26,9 +26,10 @@ import org.bytedeco.javacpp.indexer.Indexer; import org.junit.jupiter.api.*; import org.junit.jupiter.api.io.TempDir; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; @@ -54,34 +55,31 @@ import static org.junit.jupiter.api.Assertions.*; * * @author Adam Gibson */ -@RunWith(Parameterized.class) + @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") -public class DoubleDataBufferTest extends BaseNd4jTest { +public class DoubleDataBufferTest extends BaseNd4jTestWithBackends { - DataType initialType; - - public DoubleDataBufferTest(Nd4jBackend backend) { - super(backend); - initialType = Nd4j.dataType(); - } + DataType initialType = Nd4j.dataType(); @BeforeEach - public void before() { + public void before(Nd4jBackend backend) { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); } @AfterEach - public void after() { + public void after(Nd4jBackend backend) { DataTypeUtil.setDTypeForContext(initialType); } @Test - public void testPointerCreation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPointerCreation(Nd4jBackend backend) { DoublePointer floatPointer = new DoublePointer(1, 2, 3, 4); Indexer indexer = DoubleIndexer.create(floatPointer); DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.DOUBLE, 4, indexer); @@ -89,8 +87,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest { assertArrayEquals(other.asDouble(), buffer.asDouble(), 0.001); } - @Test - public void testGetSet() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetSet(Nd4jBackend backend) { double[] d1 = new double[] {1, 2, 3, 4}; DataBuffer d = Nd4j.createBuffer(d1); double[] d2 = d.asDouble(); @@ -100,10 +100,12 @@ public class DoubleDataBufferTest extends BaseNd4jTest { - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSerialization2() throws Exception { INDArray[] arr = new INDArray[] {Nd4j.ones(1, 10), - // Nd4j.ones(5,10).getRow(2) + // Nd4j.ones(5,10).getRow(2) }; for (INDArray a : arr) { @@ -128,7 +130,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSerialization(@TempDir Path testDir) throws Exception { File dir = testDir.toFile(); DataBuffer buf = Nd4j.createBuffer(5); @@ -150,8 +154,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest { } - @Test - public void testDup() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDup(Nd4jBackend backend) { double[] d1 = new double[] {1, 2, 3, 4}; DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d2 = d.dup(); @@ -160,8 +166,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest { - @Test - public void testPut() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPut(Nd4jBackend backend) { double[] d1 = new double[] {1, 2, 3, 4}; DataBuffer d = Nd4j.createBuffer(d1); d.put(0, 0.0); @@ -171,8 +179,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest { } - @Test - public void testGetRange() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRange(Nd4jBackend backend) { DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data(); double[] get = buffer.getDoublesAt(0, 3); double[] data = new double[] {1, 2, 3}; @@ -186,8 +196,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest { } - @Test - public void testGetOffsetRange() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetOffsetRange(Nd4jBackend backend) { DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data(); double[] get = buffer.getDoublesAt(1, 3); double[] data = new double[] {2, 3, 4}; @@ -201,8 +213,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest { } - @Test - public void testAssign() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAssign(Nd4jBackend backend) { DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3}); DataBuffer one = Nd4j.createBuffer(new double[] {1}); DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3}); @@ -212,8 +226,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest { } - @Test - public void testOffset() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOffset(Nd4jBackend backend) { DataBuffer create = Nd4j.createBuffer(new double[] {1, 2, 3, 4}, 2); assertEquals(2, create.length()); assertEquals(0, create.offset()); @@ -222,8 +238,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest { } - @Test - public void testReallocation() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReallocation(Nd4jBackend backend) { DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4}); assertEquals(4, buffer.capacity()); double[] old = buffer.asDouble(); @@ -232,10 +250,12 @@ public class DoubleDataBufferTest extends BaseNd4jTest { assertArrayEquals(old, Arrays.copyOf(buffer.asDouble(), 4), 1e-1); } - @Test - public void testReallocationWorkspace() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReallocationWorkspace(Nd4jBackend backend) { WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) - .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); + .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID"); DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4}); @@ -249,7 +269,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testAddressPointer(){ if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){ return; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java index 1dce2d107..d37aca6d6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java @@ -27,7 +27,9 @@ import org.bytedeco.javacpp.indexer.Indexer; import org.junit.jupiter.api.*; import org.junit.jupiter.api.io.TempDir; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; @@ -54,14 +56,9 @@ import static org.junit.jupiter.api.Assertions.*; * @author Adam Gibson */ @Disabled("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") -public class FloatDataBufferTest extends BaseNd4jTest { +public class FloatDataBufferTest extends BaseNd4jTestWithBackends { - DataType initialType; - - public FloatDataBufferTest(Nd4jBackend backend) { - super(backend); - initialType = Nd4j.dataType(); - } + DataType initialType = Nd4j.dataType(); @BeforeEach public void before() { @@ -76,7 +73,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { @Test - public void testPointerCreation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPointerCreation(Nd4jBackend backend) { FloatPointer floatPointer = new FloatPointer(1, 2, 3, 4); Indexer indexer = FloatIndexer.create(floatPointer); DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.FLOAT, 4, indexer); @@ -85,7 +84,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { } @Test - public void testGetSet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetSet(Nd4jBackend backend) { float[] d1 = new float[] {1, 2, 3, 4}; DataBuffer d = Nd4j.createBuffer(d1); float[] d2 = d.asFloat(); @@ -96,7 +97,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { @Test - public void testSerialization(@TempDir Path tempDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSerialization(@TempDir Path tempDir,Nd4jBackend backend) throws Exception { File dir = tempDir.toFile(); DataBuffer buf = Nd4j.createBuffer(5); String fileName = "buf.ser"; @@ -117,7 +120,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { } @Test - public void testDup() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDup(Nd4jBackend backend) { float[] d1 = new float[] {1, 2, 3, 4}; DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d2 = d.dup(); @@ -125,7 +130,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { } @Test - public void testToNio() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToNio(Nd4jBackend backend) { DataBuffer buff = Nd4j.createTypedBuffer(new double[] {1, 2, 3, 4}, DataType.FLOAT); assertEquals(4, buff.length()); if (buff.allocationMode() == DataBuffer.AllocationMode.HEAP) @@ -137,7 +144,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { } @Test - public void testPut() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPut(Nd4jBackend backend) { float[] d1 = new float[] {1, 2, 3, 4}; DataBuffer d = Nd4j.createBuffer(d1); d.put(0, 0.0); @@ -148,7 +157,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { @Test - public void testGetRange() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRange(Nd4jBackend backend) { DataBuffer buffer = Nd4j.linspace(1, 5, 5).data(); float[] get = buffer.getFloatsAt(0, 3); float[] data = new float[] {1, 2, 3}; @@ -164,7 +175,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { @Test - public void testGetOffsetRange() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetOffsetRange(Nd4jBackend backend) { DataBuffer buffer = Nd4j.linspace(1, 5, 5).data(); float[] get = buffer.getFloatsAt(1, 3); float[] data = new float[] {2, 3, 4}; @@ -181,7 +194,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { @Test - public void testAsBytes() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAsBytes(Nd4jBackend backend) { INDArray arr = Nd4j.create(5); byte[] d = arr.data().asBytes(); assertEquals(4 * 5, d.length,getFailureMessage()); @@ -191,7 +206,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { } @Test - public void testAssign() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAssign(Nd4jBackend backend) { DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3}); DataBuffer one = Nd4j.createBuffer(new double[] {1}); DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3}); @@ -201,7 +218,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { } @Test - public void testReadWrite() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReadWrite(Nd4jBackend backend) throws Exception { DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3}); ByteArrayOutputStream bos = new ByteArrayOutputStream(); DataOutputStream dos = new DataOutputStream(bos); @@ -215,7 +234,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { } @Test - public void testOffset() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOffset(Nd4jBackend backend) { DataBuffer create = Nd4j.createBuffer(new float[] {1, 2, 3, 4}, 2); assertEquals(2, create.length()); assertEquals(0, create.offset()); @@ -225,7 +246,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { } @Test - public void testReallocation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReallocation(Nd4jBackend backend) { DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4}); assertEquals(4, buffer.capacity()); float[] old = buffer.asFloat(); @@ -236,7 +259,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { } @Test - public void testReallocationWorkspace() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReallocationWorkspace(Nd4jBackend backend) { WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID"); @@ -253,7 +278,9 @@ public class FloatDataBufferTest extends BaseNd4jTest { } @Test - public void testAddressPointer(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAddressPointer(Nd4jBackend backend){ if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){ return; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java index af3f277f8..1dccbb338 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java @@ -23,7 +23,9 @@ package org.nd4j.linalg.api.buffer; import lombok.val; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.enums.AllocationPolicy; @@ -37,13 +39,12 @@ import java.util.Arrays; import static org.junit.jupiter.api.Assertions.*; -public class IntDataBufferTests extends BaseNd4jTest { +public class IntDataBufferTests extends BaseNd4jTestWithBackends { - public IntDataBufferTests(Nd4jBackend backend) { - super(backend); - } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBasicSerde1() throws Exception { @@ -82,7 +83,9 @@ public class IntDataBufferTests extends BaseNd4jTest { */ @Test - public void testReallocation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReallocation(Nd4jBackend backend) { DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4}); assertEquals(4, buffer.capacity()); buffer.reallocate(6); @@ -94,9 +97,11 @@ public class IntDataBufferTests extends BaseNd4jTest { } @Test - public void testReallocationWorkspace() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReallocationWorkspace(Nd4jBackend backend) { WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) - .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); + .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID"); DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java index 0f984c9d5..1d4cc1123 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.api.indexing; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -37,17 +38,15 @@ import static org.junit.jupiter.api.Assertions.assertTrue; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class IndexingTests extends BaseNd4jTest { +public class IndexingTests extends BaseNd4jTestWithBackends { - public IndexingTests(Nd4jBackend backend) { - super(backend); - } @Test - public void testINDArrayIndexingEqualToRank() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testINDArrayIndexingEqualToRank(Nd4jBackend backend) { INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE); INDArray indexes = Nd4j.create(new double[][]{ {0,1,2}, @@ -62,7 +61,9 @@ public class IndexingTests extends BaseNd4jTest { @Test - public void testINDArrayIndexingLessThanRankSimple() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testINDArrayIndexingLessThanRankSimple(Nd4jBackend backend) { INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE); INDArray indexes = Nd4j.create(new double[][]{ {0}, @@ -76,7 +77,9 @@ public class IndexingTests extends BaseNd4jTest { @Test - public void testINDArrayIndexingLessThanRankFourDimension() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testINDArrayIndexingLessThanRankFourDimension(Nd4jBackend backend) { INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE); INDArray indexes = Nd4j.create(new double[][]{ {0},{1} @@ -89,7 +92,9 @@ public class IndexingTests extends BaseNd4jTest { } @Test - public void testPutSimple() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutSimple(Nd4jBackend backend) { INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2); INDArray indexes = Nd4j.create(new double[][]{ {0},{1} @@ -101,7 +106,9 @@ public class IndexingTests extends BaseNd4jTest { } @Test - public void testGetScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetScalar(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray d = arr.get(NDArrayIndex.point(1)); assertTrue(d.isScalar()); @@ -110,14 +117,18 @@ public class IndexingTests extends BaseNd4jTest { } @Test - public void testNewAxis() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNewAxis(Nd4jBackend backend) { INDArray arr = Nd4j.rand(new int[] {4, 2, 3}); INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1)); // System.out.println(view); } @Test - public void testVectorIndexing() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorIndexing(Nd4jBackend backend) { INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE); int[] index = new int[] {5, 8, 9}; INDArray columnsTest = x.getColumns(index); @@ -129,7 +140,9 @@ public class IndexingTests extends BaseNd4jTest { } @Test - public void testGetRowsColumnsMatrix() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRowsColumnsMatrix(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 6); INDArray firstAndSecondColumnsAssertion = Nd4j.create(new double[][] {{1, 5}, {2, 6}, {3, 7}, {4, 8}}); @@ -147,7 +160,9 @@ public class IndexingTests extends BaseNd4jTest { @Test - public void testSlicing() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSlicing(Nd4jBackend backend) { INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE); INDArray slice1Assert = Nd4j.create(new double[] {2, 6, 10, 14}); INDArray slice1Test = arange.slice(1); @@ -155,7 +170,9 @@ public class IndexingTests extends BaseNd4jTest { } @Test - public void testArangeMul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArangeMul(Nd4jBackend backend) { INDArray arange = Nd4j.arange(1, 17).reshape('f', 4, 4).castTo(DataType.DOUBLE); INDArrayIndex index = NDArrayIndex.interval(0, 2); INDArray get = arange.get(index, index); @@ -167,7 +184,9 @@ public class IndexingTests extends BaseNd4jTest { } @Test - public void testGetIndicesVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetIndicesVector(Nd4jBackend backend) { INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1); INDArray test = Nd4j.create(new double[] {2, 3}); INDArray result = line.get(NDArrayIndex.point(0), NDArrayIndex.interval(1, 3)); @@ -175,7 +194,9 @@ public class IndexingTests extends BaseNd4jTest { } @Test - public void testGetIndicesVectorView() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetIndicesVectorView(Nd4jBackend backend) { INDArray matrix = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape('c',5, 5); INDArray column = matrix.getColumn(0).reshape(1,5); INDArray test = Nd4j.create(new double[] {6, 11}); @@ -193,7 +214,9 @@ public class IndexingTests extends BaseNd4jTest { } @Test - public void test2dGetPoint(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test2dGetPoint(Nd4jBackend backend){ INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4); for( int i=0; i<3; i++ ){ INDArray exp = Nd4j.create(new double[]{i*4+1, i*4+2, i*4+3, i*4+4}); @@ -206,7 +229,7 @@ public class IndexingTests extends BaseNd4jTest { assertEquals(exp, get); } - for( int i=0; i<4; i++ ){ + for( int i = 0; i < 4; i++) { INDArray exp = Nd4j.create(new double[]{1+i, 5+i, 9+i}); INDArray col = arr.getColumn(i); INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(i)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java index 2639b2048..b9f361df3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java @@ -21,10 +21,11 @@ package org.nd4j.linalg.api.indexing; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.common.base.Preconditions; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; @@ -49,16 +50,15 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class IndexingTestsC extends BaseNd4jTest { + +public class IndexingTestsC extends BaseNd4jTestWithBackends { - public IndexingTestsC(Nd4jBackend backend) { - super(backend); - } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testNegativeBounds() { INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5); INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1)); @@ -70,7 +70,9 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(assertion,get); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testNewAxis() { INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2); INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all()); @@ -79,7 +81,9 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void broadcastBug() { INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2}); final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0)); @@ -90,7 +94,9 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIntervalsIn3D() { INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE); INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2); @@ -99,7 +105,9 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSmallInterval() { INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE); INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2); @@ -108,7 +116,9 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testAllWithNewAxisAndInterval() { INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3); @@ -117,7 +127,9 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(assertion2, get2); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testAllWithNewAxisInMiddle() { INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3); @@ -126,7 +138,9 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(assertion2, get2); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testAllWithNewAxis() { INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray get = arr.get(newAxis(), all(), point(1)); @@ -136,7 +150,9 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIndexingWithMmul() { INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3); INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1); @@ -147,7 +163,9 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(assertion, c); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPointPointInterval() { INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3); INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3)); @@ -156,7 +174,9 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(assertion, get); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIntervalLowerBound() { INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2)); @@ -167,7 +187,9 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGetPointRowVector() { INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1); @@ -177,7 +199,9 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(Nd4j.linspace(1, 100, 100, DataType.DOUBLE), arr2); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSpecifiedIndexVector() { INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4); INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); @@ -194,7 +218,9 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPutRowIndexing() { INDArray arr = Nd4j.ones(1, 10); INDArray row = Nd4j.create(1, 10); @@ -204,7 +230,9 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(arr, row); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testVectorIndexing2() { INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true)); INDArray assertion = Nd4j.create(new double[] {2, 4}); @@ -219,7 +247,9 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testOffsetsC() { INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); assertEquals(3, NDArrayIndex.offset(arr, 1, 1)); @@ -235,7 +265,9 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIndexFor() { long[] shape = {1, 2}; INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape); @@ -244,7 +276,9 @@ public class IndexingTestsC extends BaseNd4jTest { } } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGetScalar() { INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray d = arr.get(point(1)); @@ -253,7 +287,9 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testVectorIndexing() { INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1); INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5}); @@ -261,14 +297,18 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(assertion, viewTest); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testNegativeIndices() { INDArray test = Nd4j.create(10, 10, 10); test.putScalar(new int[] {0, 0, -1}, 1.0); assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber()); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGetIndices2d() { INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2); INDArray firstRow = twoByTwo.getRow(0); @@ -286,7 +326,9 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(Nd4j.create(new double[] {4}, new int[]{1,1}), individualElement); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGetRow() { Nd4j.getRandom().setSeed(12345); INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5); @@ -303,7 +345,9 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGetRowEdgeCase() { INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1); INDArray get = rowVec.getRow(0); //Returning shape [1,1] @@ -312,7 +356,9 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(rowVec, get); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGetColumnEdgeCase() { INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose(); INDArray get = colVec.getColumn(0); //Returning shape [1,1] @@ -321,7 +367,9 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(colVec, get); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testConcatColumns() { INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE); INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE); @@ -330,7 +378,9 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(assertion, concat); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGetIndicesVector() { INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1); INDArray test = Nd4j.create(new double[] {2, 3}); @@ -338,7 +388,9 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(test, result); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testArangeMul() { INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE); INDArrayIndex index = interval(0, 2); @@ -349,7 +401,9 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(assertion, mul); } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIndexingThorough(){ long[] fullShape = {3,4,5,6,7}; @@ -549,7 +603,9 @@ public class IndexingTestsC extends BaseNd4jTest { return d; } - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void debugging(){ long[] inShape = {3,4}; INDArrayIndex[] indexes = new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(1, 2, 4)}; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java index 721e5925e..1177d0a4a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/resolve/NDArrayIndexResolveTests.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.api.indexing.resolve; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -36,15 +37,14 @@ import static org.junit.jupiter.api.Assertions.*; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class NDArrayIndexResolveTests extends BaseNd4jTest { - public NDArrayIndexResolveTests(Nd4jBackend backend) { - super(backend); - } +public class NDArrayIndexResolveTests extends BaseNd4jTestWithBackends { + @Test - public void testResolvePoint() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testResolvePoint(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArrayIndex[] test = NDArrayIndex.resolve(arr.shape(), NDArrayIndex.point(1)); INDArrayIndex[] assertion = {NDArrayIndex.point(1), NDArrayIndex.all()}; @@ -59,6 +59,8 @@ public class NDArrayIndexResolveTests extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testResolvePointVector() { INDArray arr = Nd4j.linspace(1, 4, 4); INDArrayIndex[] getPoint = {NDArrayIndex.point(1)}; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java index 923911f20..db08ba1db 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.api.indexing.shape; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.Indices; @@ -34,19 +35,15 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class IndexShapeTests extends BaseNd4jTest { - - public IndexShapeTests(Nd4jBackend backend) { - super(backend); - } - +public class IndexShapeTests extends BaseNd4jTestWithBackends { private int[] shape = {1, 1, 2, 1, 3, 4, 5, 1}; @Test - public void testSinglePoint() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSinglePoint(Nd4jBackend backend) { /* Assumes all indexes are filled out. Test simple general point case @@ -77,7 +74,9 @@ public class IndexShapeTests extends BaseNd4jTest { } @Test - public void testInterval() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInterval(Nd4jBackend backend) { int[] basicAssertion = {1, 1, 1, 1, 3, 1, 2, 1}; INDArrayIndex[] basicTest = {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 1), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(1, 2), @@ -88,7 +87,9 @@ public class IndexShapeTests extends BaseNd4jTest { @Test - public void testNewAxis() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNewAxis(Nd4jBackend backend) { //normal prepend int[] prependAssertion = {1, 1, 1, 1, 2, 1, 3, 4, 5, 1}; INDArrayIndex[] prependTest = {NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.all(), diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java index b70af316e..cd81c5aa1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/shape/IndexShapeTests2d.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.api.indexing.shape; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.Indices; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -33,25 +34,26 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class IndexShapeTests2d extends BaseNd4jTest { - public IndexShapeTests2d(Nd4jBackend backend) { - super(backend); - } +public class IndexShapeTests2d extends BaseNd4jTestWithBackends { + private long[] shape = {3, 2}; @Test - public void test2dCases() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test2dCases(Nd4jBackend backend) { assertArrayEquals(new long[] {1, 2}, Indices.shape(shape, NDArrayIndex.point(1))); assertArrayEquals(new long[] {3, 1}, Indices.shape(shape, NDArrayIndex.all(), NDArrayIndex.point(1))); } @Test - public void testNewAxis2d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNewAxis2d(Nd4jBackend backend) { assertArrayEquals(new long[] {1, 3, 2}, Indices.shape(shape, NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all())); assertArrayEquals(new long[] {3, 1, 2}, Indices.shape(shape, diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java index 5c4eebb1a..c93c159f8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/iterator/NDIndexIteratorTest.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.api.iterator; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.factory.Nd4jBackend; @@ -33,15 +34,14 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class NDIndexIteratorTest extends BaseNd4jTest { - public NDIndexIteratorTest(Nd4jBackend backend) { - super(backend); - } +public class NDIndexIteratorTest extends BaseNd4jTestWithBackends { + @Test - public void testIterate() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIterate(Nd4jBackend backend) { val shapeIter = new NdIndexIterator(2, 2); val possibleSolutions = new long[][] {{0, 0}, {0, 1}, {1, 0}, {1, 1},}; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java index d74759bc0..bc2859129 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java @@ -28,9 +28,10 @@ import org.apache.commons.lang3.ArrayUtils; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -45,18 +46,15 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j -@RunWith(Parameterized.class) -public class TestNdArrReadWriteTxt extends BaseNd4jTest { - - public TestNdArrReadWriteTxt(Nd4jBackend backend) { - super(backend); - } +public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends { @Test - public void compareAfterWrite(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception { int [] ranksToCheck = new int[] {0,1,2,3,4}; - for (int i=0; i p : list) { INDArray arr = p.getFirst().assign(testValues); @@ -256,7 +261,9 @@ public class TestTensorAlongDimension extends BaseNd4jTest { } @Test - public void testTadKnownValues() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTadKnownValues(Nd4jBackend backend) { long[] shape = {2, 3, 4}; INDArray arr = Nd4j.create(DataType.DOUBLE, shape); @@ -277,7 +284,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { INDArray exp12_0 = Nd4j.create(new double[][] {{0, 1, 2, 3}, {10, 11, 12, 13}, {20, 21, 22, 23}}); INDArray exp12_1 = - Nd4j.create(new double[][] {{100, 101, 102, 103}, {110, 111, 112, 113}, {120, 121, 122, 123}}); + Nd4j.create(new double[][] {{100, 101, 102, 103}, {110, 111, 112, 113}, {120, 121, 122, 123}}); assertEquals(exp01_0, arr.tensorAlongDimension(0, 0, 1)); assertEquals(exp01_0, arr.tensorAlongDimension(0, 1, 0)); @@ -296,7 +303,9 @@ public class TestTensorAlongDimension extends BaseNd4jTest { } @Test - public void testStalled() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStalled(Nd4jBackend backend) { int shape[] = new int[] {3, 3, 4, 5}; INDArray orig2 = Nd4j.create(shape, 'c'); System.out.println("Shape: " + Arrays.toString(orig2.shapeInfoDataBuffer().asInt())); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java index a4bb53bd1..23a1a93cb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/blas/BlasTests.java @@ -25,9 +25,10 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -39,15 +40,14 @@ import java.util.Collections; import static org.junit.jupiter.api.Assertions.*; @Slf4j -@RunWith(Parameterized.class) -public class BlasTests extends BaseNd4jTest { - public BlasTests(Nd4jBackend backend) { - super(backend); - } +public class BlasTests extends BaseNd4jTestWithBackends { + @Test - public void simpleTest() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void simpleTest(Nd4jBackend backend) { INDArray m1 = Nd4j.create(new double[][]{{1.0}, {2.0}, {3.0}, {4.0}}); m1 = m1.reshape(2, 2); @@ -77,7 +77,9 @@ public class BlasTests extends BaseNd4jTest { @Test - public void testGemmInvalid1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemmInvalid1(Nd4jBackend backend) { final INDArray a = Nd4j.rand(3, 4); final INDArray b = Nd4j.rand(4, 5); @@ -93,7 +95,9 @@ public class BlasTests extends BaseNd4jTest { } @Test - public void testGemmInvalid3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemmInvalid3(Nd4jBackend backend) { final INDArray a = Nd4j.rand(4, 3); final INDArray b = Nd4j.rand(4, 5); @@ -109,7 +113,9 @@ public class BlasTests extends BaseNd4jTest { } @Test - public void testGemm1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemm1(Nd4jBackend backend) { final INDArray a = Nd4j.rand(4, 3); final INDArray b = Nd4j.rand(4, 5); @@ -120,7 +126,9 @@ public class BlasTests extends BaseNd4jTest { } @Test - public void testGemm2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemm2(Nd4jBackend backend) { final INDArray a = Nd4j.rand(4, 3); final INDArray b = Nd4j.rand(4, 5); @@ -135,7 +143,9 @@ public class BlasTests extends BaseNd4jTest { } @Test - public void testGemm3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGemm3(Nd4jBackend backend) { final INDArray a = Nd4j.rand(4, 3); final INDArray b = Nd4j.rand(4, 5); @@ -151,7 +161,9 @@ public class BlasTests extends BaseNd4jTest { @Test - public void testMmuli1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmuli1(Nd4jBackend backend) { final INDArray activations = Nd4j.createUninitialized(new long[]{1, 3, 1}, 'f'); final INDArray z = activations.tensorAlongDimension(0, 1, 2); @@ -165,7 +177,9 @@ public class BlasTests extends BaseNd4jTest { } @Test - public void testMmuli2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmuli2(Nd4jBackend backend) { final INDArray activations = Nd4j.createUninitialized(new long[]{2, 3, 1}, 'f'); final INDArray z = activations.tensorAlongDimension(0, 1, 2); @@ -179,7 +193,9 @@ public class BlasTests extends BaseNd4jTest { } @Test - public void testMmuli3(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmuli3(Nd4jBackend backend){ final INDArray activations = Nd4j.createUninitialized(new long[]{1, 3, 2}, 'f'); final INDArray z = activations.tensorAlongDimension(0, 1, 2); @@ -192,7 +208,9 @@ public class BlasTests extends BaseNd4jTest { } @Test - public void test_Fp16_Mmuli_1(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test_Fp16_Mmuli_1(Nd4jBackend backend){ final INDArray activations = Nd4j.createUninitialized(DataType.HALF, new long[]{1, 3, 2}, 'f'); final INDArray z = activations.tensorAlongDimension(0, 1, 2); @@ -205,7 +223,9 @@ public class BlasTests extends BaseNd4jTest { } @Test - public void test_Fp16_Mmuli_2(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test_Fp16_Mmuli_2(Nd4jBackend backend){ val a = Nd4j.create(DataType.HALF, 32, 768); val b = Nd4j.create(DataType.HALF, 768); @@ -214,7 +234,9 @@ public class BlasTests extends BaseNd4jTest { @Test @Disabled - public void testHalfPrecision() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHalfPrecision(Nd4jBackend backend) { val a = Nd4j.create(DataType.HALF, 64, 768); val b = Nd4j.create(DataType.HALF, 768, 1024); val c = Nd4j.create(DataType.HALF, new long[]{64, 1024}, 'f'); @@ -234,7 +256,9 @@ public class BlasTests extends BaseNd4jTest { } @Test - public void testMmuli4(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmuli4(Nd4jBackend backend){ try { Nd4j.rand(1, 3).mmuli(Nd4j.rand(3, 1), Nd4j.createUninitialized(new int[]{10, 10, 1})); fail("Expected exception"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java index 5eb2357bb..911a1f31b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java @@ -22,18 +22,17 @@ package org.nd4j.linalg.broadcast; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -42,14 +41,13 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j -@RunWith(Parameterized.class) -public class BasicBroadcastTests extends BaseNd4jTest { - public BasicBroadcastTests(Nd4jBackend backend) { - super(backend); - } + +public class BasicBroadcastTests extends BaseNd4jTestWithBackends { @Test - public void basicBroadcastTest_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastTest_1(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 3, 5); val y = Nd4j.createFromArray(new float[]{1.f, 1.f, 1.f, 1.f, 1.f}); val e = Nd4j.create(DataType.FLOAT, 3, 5).assign(1.f); @@ -63,7 +61,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test - public void basicBroadcastTest_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastTest_2(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2); val y = Nd4j.createFromArray(new float[]{1.f, 1.f, 1.f, 1.f}).reshape(2, 2); val e = Nd4j.create(DataType.FLOAT, 3, 2, 2).assign(1.f); @@ -78,7 +78,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { @Test - public void basicBroadcastTest_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastTest_3(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(1); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val e = Nd4j.create(DataType.FLOAT, 3, 2, 2).assign(2.f); @@ -89,7 +91,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test - public void basicBroadcastTest_4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastTest_4(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val e = Nd4j.create(DataType.FLOAT, 3, 2, 2).assign(2.f); @@ -100,7 +104,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test - public void basicBroadcastTest_5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastTest_5(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val e = Nd4j.create(DataType.FLOAT, 3, 2, 2).assign(2.f); @@ -111,7 +117,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test - public void basicBroadcastTest_6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastTest_6(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val e = Nd4j.create(DataType.FLOAT, 3, 2, 2).assign(-2.f); @@ -122,7 +130,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test - public void basicBroadcastTest_7() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastTest_7(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val e = Nd4j.create(DataType.BOOL, 3, 2, 2).assign(false); @@ -133,7 +143,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test() - public void basicBroadcastFailureTest_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastFailureTest_1(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); @@ -142,7 +154,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test() - public void basicBroadcastFailureTest_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastFailureTest_2(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); @@ -152,7 +166,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test() - public void basicBroadcastFailureTest_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastFailureTest_3(Nd4jBackend backend) { assertThrows(IllegalStateException.class, () -> { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); @@ -162,14 +178,18 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test() - public void basicBroadcastFailureTest_4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastFailureTest_4(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val z = x.addi(y); } @Test() - public void basicBroadcastFailureTest_5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastFailureTest_5(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); @@ -179,7 +199,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test() - public void basicBroadcastFailureTest_6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastFailureTest_6(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); @@ -189,7 +211,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test - public void basicBroadcastTest_8() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastTest_8(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val e = Nd4j.create(DataType.BOOL, 3, 2, 2).assign(true); @@ -200,7 +224,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test - public void basicBroadcastTest_9() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastTest_9(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(2.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val e = Nd4j.create(DataType.BOOL, 3, 2, 2).assign(true); @@ -211,7 +237,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test - public void basicBroadcastTest_10() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void basicBroadcastTest_10(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(1.f); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val e = Nd4j.create(DataType.BOOL, 3, 2, 2).assign(false); @@ -222,7 +250,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test - public void emptyBroadcastTest_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void emptyBroadcastTest_1(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 1, 2); val y = Nd4j.create(DataType.FLOAT, 0, 2); @@ -231,7 +261,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test() - public void emptyBroadcastTest_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void emptyBroadcastTest_2(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 1, 2); val y = Nd4j.create(DataType.FLOAT, 0, 2); @@ -241,7 +273,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test - public void emptyBroadcastTest_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void emptyBroadcastTest_3(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 1, 0, 1); val y = Nd4j.create(DataType.FLOAT, 1, 0, 2); @@ -253,7 +287,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { @Test - public void testValidInvalidBroadcast(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testValidInvalidBroadcast(Nd4jBackend backend){ INDArray x = Nd4j.rand(3,1); INDArray y = Nd4j.create(3, 4); @@ -313,7 +349,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test - public void testLt(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLt(Nd4jBackend backend){ INDArray x = Nd4j.scalar(0); INDArray y = Nd4j.createFromArray(2,1,2); @@ -325,7 +363,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test - public void testAdd(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdd(Nd4jBackend backend){ INDArray x = Nd4j.scalar(0); INDArray y = Nd4j.createFromArray(2,1,2); @@ -337,7 +377,9 @@ public class BasicBroadcastTests extends BaseNd4jTest { } @Test - public void testBroadcatableBool_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcatableBool_1(Nd4jBackend backend) { val op = DynamicCustomOp.builder("greater_equal") .addInputs(Nd4j.create(DataType.FLOAT, 3), Nd4j.create(DataType.FLOAT, 3)) .build(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionMagicTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionMagicTests.java index e0c76eac6..a625425c5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionMagicTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionMagicTests.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.compression; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,19 +33,18 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class CompressionMagicTests extends BaseNd4jTest { - public CompressionMagicTests(Nd4jBackend backend) { - super(backend); - } + +public class CompressionMagicTests extends BaseNd4jTestWithBackends { @BeforeEach - public void setUp() { + public void setUp(Nd4jBackend backend) { } @Test - public void testMagicDecompression1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMagicDecompression1(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 100, 2500, DataType.FLOAT); INDArray compressed = Nd4j.getCompressor().compress(array, "GZIP"); @@ -57,7 +57,9 @@ public class CompressionMagicTests extends BaseNd4jTest { } @Test - public void testMagicDecompression4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMagicDecompression4(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 100, 2500, DataType.FLOAT); INDArray compressed = Nd4j.getCompressor().compress(array, "GZIP"); @@ -71,7 +73,9 @@ public class CompressionMagicTests extends BaseNd4jTest { } @Test - public void testDupSkipDecompression1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDupSkipDecompression1(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 100, 2500, DataType.FLOAT); INDArray compressed = Nd4j.getCompressor().compress(array, "GZIP"); @@ -87,7 +91,9 @@ public class CompressionMagicTests extends BaseNd4jTest { } @Test - public void testDupSkipDecompression2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDupSkipDecompression2(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 100, 2500, DataType.FLOAT); INDArray compressed = Nd4j.getCompressor().compress(array, "GZIP"); @@ -103,7 +109,9 @@ public class CompressionMagicTests extends BaseNd4jTest { } @Test - public void testDupSkipDecompression3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDupSkipDecompression3(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 100, 2500, DataType.FLOAT); INDArray compressed = Nd4j.getCompressor().compress(array, "GZIP"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionPerformanceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionPerformanceTests.java index fd271faa0..6eb0e9dc5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionPerformanceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionPerformanceTests.java @@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -37,16 +38,15 @@ import java.io.ByteArrayOutputStream; @Slf4j @Disabled -@RunWith(Parameterized.class) -public class CompressionPerformanceTests extends BaseNd4jTest { - public CompressionPerformanceTests(Nd4jBackend backend) { - super(backend); - } +public class CompressionPerformanceTests extends BaseNd4jTestWithBackends { + @Test - public void groundTruthTests_Threshold_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void groundTruthTests_Threshold_1(Nd4jBackend backend) { Nd4j.getRandom().setSeed(119); val params = Nd4j.rand(new long[]{1, 50000000}, -1.0, 1.0, Nd4j.getRandom()); val original = params.dup(params.ordering()); @@ -88,7 +88,9 @@ public class CompressionPerformanceTests extends BaseNd4jTest { } @Test - public void groundTruthTests_Bitmap_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void groundTruthTests_Bitmap_1(Nd4jBackend backend) { Nd4j.getRandom().setSeed(119); val params = Nd4j.rand(new long[]{1, 25000000}, -1.0, 1.0, Nd4j.getRandom()); val original = params.dup(params.ordering()); @@ -115,7 +117,7 @@ public class CompressionPerformanceTests extends BaseNd4jTest { log.info("Encoding time: {} ms;", time / iterations); } - @Override + @Override public char ordering() { return 'c'; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionSerDeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionSerDeTests.java index 535db0317..f495cbfee 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionSerDeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionSerDeTests.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.compression; import org.apache.commons.io.output.ByteArrayOutputStream; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -34,15 +35,14 @@ import java.io.ByteArrayInputStream; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) -public class CompressionSerDeTests extends BaseNd4jTest { - public CompressionSerDeTests(Nd4jBackend backend) { - super(backend); - } + +public class CompressionSerDeTests extends BaseNd4jTestWithBackends { @Test - public void testAutoDecompression2() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAutoDecompression2(Nd4jBackend backend) throws Exception { INDArray array = Nd4j.linspace(1, 10, 11, DataType.DOUBLE); INDArray compressed = Nd4j.getCompressor().compress(array, "GZIP"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java index ee57fd951..754bbe985 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionTests.java @@ -24,10 +24,10 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; -import org.nd4j.linalg.api.buffer.DataBuffer; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; @@ -47,16 +47,15 @@ import static junit.framework.TestCase.assertFalse; import static org.junit.jupiter.api.Assertions.*; @Slf4j -@RunWith(Parameterized.class) -public class CompressionTests extends BaseNd4jTest { - public CompressionTests(Nd4jBackend backend) { - super(backend); - } +public class CompressionTests extends BaseNd4jTestWithBackends { + @Test - public void testCompressionDescriptorSerde() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCompressionDescriptorSerde(Nd4jBackend backend) { CompressionDescriptor descriptor = new CompressionDescriptor(); descriptor.setCompressedLength(4); descriptor.setOriginalElementSize(4); @@ -71,7 +70,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testGzipInPlaceCompression() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGzipInPlaceCompression(Nd4jBackend backend) { INDArray array = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); Nd4j.getCompressor().setDefaultCompression("GZIP"); Nd4j.getCompressor().compressi(array); @@ -81,7 +82,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testGzipCompression1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGzipCompression1(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 10000, 20000, DataType.FLOAT); INDArray exp = array.dup(); @@ -98,7 +101,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testNoOpCompression1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNoOpCompression1(Nd4jBackend backend) { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); INDArray array = Nd4j.linspace(1, 10000, 20000, DataType.FLOAT); INDArray exp = Nd4j.linspace(1, 10000, 20000, DataType.FLOAT); @@ -124,7 +129,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testJVMCompression3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJVMCompression3(Nd4jBackend backend) { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); INDArray exp = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}).reshape(1,-1); @@ -143,7 +150,9 @@ public class CompressionTests extends BaseNd4jTest { @Disabled @Test - public void testThresholdCompression0() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThresholdCompression0(Nd4jBackend backend) { INDArray initial = Nd4j.rand(new int[] {1, 150000000}, 119L); log.info("DTYPE: {}", Nd4j.dataType()); @@ -174,7 +183,9 @@ public class CompressionTests extends BaseNd4jTest { @Test @Disabled - public void testThresholdCompression1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThresholdCompression1(Nd4jBackend backend) { INDArray initial = Nd4j.create(new float[] {0.0f, 0.0f, 1e-3f, -1e-3f, 0.0f, 0.0f}); INDArray exp_0 = Nd4j.create(DataType.FLOAT, 6); INDArray exp_1 = initial.dup(); @@ -193,7 +204,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testThresholdCompression2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThresholdCompression2(Nd4jBackend backend) { INDArray initial = Nd4j.create(new double[] {1.0, 2.0, 0.0, 0.0, -1.0, -1.0}); INDArray exp_0 = Nd4j.create(new double[] {1.0 - 1e-3, 2.0 - 1e-3, 0.0, 0.0, -1.0 + 1e-3, -1.0 + 1e-3}); INDArray exp_1 = Nd4j.create(new double[] {1e-3, 1e-3, 0.0, 0.0, -1e-3, -1e-3}); @@ -215,7 +228,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testThresholdCompression3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThresholdCompression3(Nd4jBackend backend) { INDArray initial = Nd4j.create(new double[] {-1.0, -2.0, 0.0, 0.0, 1.0, 1.0}); INDArray exp_0 = Nd4j.create(new double[] {-1.0 + 1e-3, -2.0 + 1e-3, 0.0, 0.0, 1.0 - 1e-3, 1.0 - 1e-3}); INDArray exp_1 = Nd4j.create(new double[] {-1e-3, -1e-3, 0.0, 0.0, 1e-3, 1e-3}); @@ -244,7 +259,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testThresholdCompression4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThresholdCompression4(Nd4jBackend backend) { INDArray initial = Nd4j.create(new double[] {1e-4, -1e-4, 0.0, 0.0, 1e-4, -1e-4}); INDArray exp_0 = initial.dup(); @@ -262,7 +279,9 @@ public class CompressionTests extends BaseNd4jTest { @Test - public void testThresholdCompression5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThresholdCompression5(Nd4jBackend backend) { INDArray initial = Nd4j.ones(10); INDArray exp_0 = initial.dup(); @@ -279,7 +298,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testThresholdCompression5_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThresholdCompression5_1(Nd4jBackend backend) { INDArray initial = Nd4j.ones(1000); INDArray exp_0 = initial.dup(); @@ -296,7 +317,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testThresholdCompression6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThresholdCompression6(Nd4jBackend backend) { INDArray initial = Nd4j.create(new double[] {1.0, 2.0, 0.0, 0.0, -1.0, -1.0}); INDArray exp_0 = Nd4j.create(new double[] {1.0 - 1e-3, 2.0 - 1e-3, 0.0, 0.0, -1.0 + 1e-3, -1.0 + 1e-3}); INDArray exp_1 = Nd4j.create(new double[] {1e-3, 1e-3, 0.0, 0.0, -1e-3, -1e-3}); @@ -325,7 +348,9 @@ public class CompressionTests extends BaseNd4jTest { @Test - public void testThresholdSerialization1() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThresholdSerialization1(Nd4jBackend backend) throws Exception { INDArray initial = Nd4j.create(new double[] {-1.0, -2.0, 0.0, 0.0, 1.0, 1.0}); INDArray exp_0 = Nd4j.create(new double[] {-1.0 + 1e-3, -2.0 + 1e-3, 0.0, 0.0, 1.0 - 1e-3, 1.0 - 1e-3}); INDArray exp_1 = Nd4j.create(new double[] {-1e-3, -1e-3, 0.0, 0.0, 1e-3, 1e-3}); @@ -347,7 +372,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testBitmapEncoding1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBitmapEncoding1(Nd4jBackend backend) { INDArray initial = Nd4j.create(new float[] {0.0f, 0.0f, 1e-3f, -1e-3f, 0.0f, 0.0f}); INDArray exp_0 = Nd4j.create(DataType.FLOAT, 6); INDArray exp_1 = initial.dup(); @@ -369,7 +396,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testBitmapEncoding1_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBitmapEncoding1_1(Nd4jBackend backend) { INDArray initial = Nd4j.create(15); INDArray exp_0 = Nd4j.create(6); INDArray exp_1 = initial.dup(); @@ -393,7 +422,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testBitmapEncoding2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBitmapEncoding2(Nd4jBackend backend) { INDArray initial = Nd4j.create(40000000); INDArray target = Nd4j.create(initial.length()); @@ -413,7 +444,9 @@ public class CompressionTests extends BaseNd4jTest { @Test - public void testBitmapEncoding3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBitmapEncoding3(Nd4jBackend backend) { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); INDArray initial = Nd4j.create(new float[] {0.0f, -6e-4f, 1e-3f, -1e-3f, 0.0f, 0.0f}); INDArray exp_0 = Nd4j.create(new float[] {0.0f, -1e-4f, 0.0f, 0.0f, 0.0f, 0.0f}); @@ -440,7 +473,9 @@ public class CompressionTests extends BaseNd4jTest { @Test - public void testBitmapEncoding4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBitmapEncoding4(Nd4jBackend backend) { Nd4j.getRandom().setSeed(119); INDArray initial = Nd4j.rand(new int[]{1, 10000}, 0, 1, Nd4j.getRandom()); INDArray exp_1 = initial.dup(); @@ -453,7 +488,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testBitmapEncoding5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBitmapEncoding5(Nd4jBackend backend) { Nd4j.getRandom().setSeed(119); INDArray initial = Nd4j.rand(new int[]{10000}, -1, -0.5, Nd4j.getRandom()); INDArray exp_0 = initial.dup().addi(1e-1); @@ -468,7 +505,9 @@ public class CompressionTests extends BaseNd4jTest { } @Test - public void testBitmapEncoding6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBitmapEncoding6(Nd4jBackend backend) { Nd4j.getRandom().setSeed(119); INDArray initial = Nd4j.rand(new int[]{10000}, -1, 1, Nd4j.getRandom()); INDArray exp_1 = initial.dup(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java index 53bf93d3a..4491f485b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTests.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.convolution; import lombok.val; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.AllocUtil; @@ -48,16 +49,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.point; -@RunWith(Parameterized.class) -public class ConvolutionTests extends BaseNd4jTest { - - public ConvolutionTests(Nd4jBackend backend) { - super(backend); - } +public class ConvolutionTests extends BaseNd4jTestWithBackends { @Test - public void testIm2ColKnownValues() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIm2ColKnownValues(Nd4jBackend backend) { //Input: w=3, h=3, depth=2, minibatch = 2 //kH=2, kW=2 /* @@ -112,13 +110,13 @@ public class ConvolutionTests extends BaseNd4jTest { //Input data: shape [miniBatch,depth,height,width] INDArray input = Nd4j.create(new int[] {miniBatch, depth, height, width}, 'c'); input.put(new INDArrayIndex[] {point(0), point(0), all(), - all()}, Nd4j.create(new double[][] {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}})); + all()}, Nd4j.create(new double[][] {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}})); input.put(new INDArrayIndex[] {point(0), point(1), all(), - all()}, Nd4j.create(new double[][] {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}})); + all()}, Nd4j.create(new double[][] {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}})); input.put(new INDArrayIndex[] {point(1), point(0), all(), - all()}, Nd4j.create(new double[][] {{18, 19, 20}, {21, 22, 23}, {24, 25, 26}})); + all()}, Nd4j.create(new double[][] {{18, 19, 20}, {21, 22, 23}, {24, 25, 26}})); input.put(new INDArrayIndex[] {point(1), point(1), all(), - all()}, Nd4j.create(new double[][] {{27, 28, 29}, {30, 31, 32}, {33, 34, 35}})); + all()}, Nd4j.create(new double[][] {{27, 28, 29}, {30, 31, 32}, {33, 34, 35}})); //Expected data: INDArray expected = Nd4j.create(new int[] {miniBatch, depth, kH, kW, outH, outW}, 'c'); @@ -127,57 +125,57 @@ public class ConvolutionTests extends BaseNd4jTest { //depth 0 expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{0, 1}, {3, 4}})); + Nd4j.create(new double[][] {{0, 1}, {3, 4}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{1, 2}, {4, 5}})); + Nd4j.create(new double[][] {{1, 2}, {4, 5}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{3, 4}, {6, 7}})); + Nd4j.create(new double[][] {{3, 4}, {6, 7}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{4, 5}, {7, 8}})); + Nd4j.create(new double[][] {{4, 5}, {7, 8}})); //depth 1 expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{9, 10}, {12, 13}})); + Nd4j.create(new double[][] {{9, 10}, {12, 13}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{10, 11}, {13, 14}})); + Nd4j.create(new double[][] {{10, 11}, {13, 14}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{12, 13}, {15, 16}})); + Nd4j.create(new double[][] {{12, 13}, {15, 16}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{13, 14}, {16, 17}})); + Nd4j.create(new double[][] {{13, 14}, {16, 17}})); //Example 1 //depth 0 expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{18, 19}, {21, 22}})); + Nd4j.create(new double[][] {{18, 19}, {21, 22}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{19, 20}, {22, 23}})); + Nd4j.create(new double[][] {{19, 20}, {22, 23}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{21, 22}, {24, 25}})); + Nd4j.create(new double[][] {{21, 22}, {24, 25}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{22, 23}, {25, 26}})); + Nd4j.create(new double[][] {{22, 23}, {25, 26}})); //depth 1 expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{27, 28}, {30, 31}})); + Nd4j.create(new double[][] {{27, 28}, {30, 31}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{28, 29}, {31, 32}})); + Nd4j.create(new double[][] {{28, 29}, {31, 32}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{30, 31}, {33, 34}})); + Nd4j.create(new double[][] {{30, 31}, {33, 34}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{31, 32}, {34, 35}})); + Nd4j.create(new double[][] {{31, 32}, {34, 35}})); INDArray out = Convolution.im2col(input, kH, kW, sY, sX, pY, pX, false); assertEquals(expected, out); @@ -196,7 +194,9 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testIm2ColKnownValuesDilated() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIm2ColKnownValuesDilated(Nd4jBackend backend) { //Input: w=4, h=4, depth=1, minibatch = 2, dilation=2, stride 1 //kH=2, kW=2 /* @@ -309,7 +309,9 @@ public class ConvolutionTests extends BaseNd4jTest { } @Test - public void testIm2ColKnownValuesDilatedStrided() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIm2ColKnownValuesDilatedStrided(Nd4jBackend backend) { //Input: w=5, h=5, depth=1, minibatch = 1, dilation=2, stride 2 //kH=2, kW=2 /* @@ -391,7 +393,9 @@ public class ConvolutionTests extends BaseNd4jTest { } @Test - public void testIm2ColKnownValuesMiniBatch3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIm2ColKnownValuesMiniBatch3(Nd4jBackend backend) { //Input: w=3, h=3, depth=2, minibatch = 3 //kH=2, kW=2 /* @@ -461,17 +465,17 @@ public class ConvolutionTests extends BaseNd4jTest { //Input data: shape [miniBatch,depth,height,width] INDArray input = Nd4j.create(new int[] {miniBatch, depth, height, width}, 'c'); input.put(new INDArrayIndex[] {point(0), point(0), all(), - all()}, Nd4j.create(new double[][] {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}})); + all()}, Nd4j.create(new double[][] {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}})); input.put(new INDArrayIndex[] {point(0), point(1), all(), - all()}, Nd4j.create(new double[][] {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}})); + all()}, Nd4j.create(new double[][] {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}})); input.put(new INDArrayIndex[] {point(1), point(0), all(), - all()}, Nd4j.create(new double[][] {{18, 19, 20}, {21, 22, 23}, {24, 25, 26}})); + all()}, Nd4j.create(new double[][] {{18, 19, 20}, {21, 22, 23}, {24, 25, 26}})); input.put(new INDArrayIndex[] {point(1), point(1), all(), - all()}, Nd4j.create(new double[][] {{27, 28, 29}, {30, 31, 32}, {33, 34, 35}})); + all()}, Nd4j.create(new double[][] {{27, 28, 29}, {30, 31, 32}, {33, 34, 35}})); input.put(new INDArrayIndex[] {point(2), point(0), all(), - all()}, Nd4j.create(new double[][] {{36, 37, 38}, {39, 40, 41}, {42, 43, 44}})); + all()}, Nd4j.create(new double[][] {{36, 37, 38}, {39, 40, 41}, {42, 43, 44}})); input.put(new INDArrayIndex[] {point(2), point(1), all(), - all()}, Nd4j.create(new double[][] {{45, 46, 47}, {48, 49, 50}, {51, 52, 53}})); + all()}, Nd4j.create(new double[][] {{45, 46, 47}, {48, 49, 50}, {51, 52, 53}})); //Expected data: INDArray expected = Nd4j.create(new int[] {miniBatch, depth, kH, kW, outH, outW}, 'c'); @@ -480,85 +484,85 @@ public class ConvolutionTests extends BaseNd4jTest { //depth 0 expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{0, 1}, {3, 4}})); + Nd4j.create(new double[][] {{0, 1}, {3, 4}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{1, 2}, {4, 5}})); + Nd4j.create(new double[][] {{1, 2}, {4, 5}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{3, 4}, {6, 7}})); + Nd4j.create(new double[][] {{3, 4}, {6, 7}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{4, 5}, {7, 8}})); + Nd4j.create(new double[][] {{4, 5}, {7, 8}})); //depth 1 expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{9, 10}, {12, 13}})); + Nd4j.create(new double[][] {{9, 10}, {12, 13}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{10, 11}, {13, 14}})); + Nd4j.create(new double[][] {{10, 11}, {13, 14}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{12, 13}, {15, 16}})); + Nd4j.create(new double[][] {{12, 13}, {15, 16}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{13, 14}, {16, 17}})); + Nd4j.create(new double[][] {{13, 14}, {16, 17}})); //Example 1 //depth 0 expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{18, 19}, {21, 22}})); + Nd4j.create(new double[][] {{18, 19}, {21, 22}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{19, 20}, {22, 23}})); + Nd4j.create(new double[][] {{19, 20}, {22, 23}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{21, 22}, {24, 25}})); + Nd4j.create(new double[][] {{21, 22}, {24, 25}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{22, 23}, {25, 26}})); + Nd4j.create(new double[][] {{22, 23}, {25, 26}})); //depth 1 expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{27, 28}, {30, 31}})); + Nd4j.create(new double[][] {{27, 28}, {30, 31}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{28, 29}, {31, 32}})); + Nd4j.create(new double[][] {{28, 29}, {31, 32}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{30, 31}, {33, 34}})); + Nd4j.create(new double[][] {{30, 31}, {33, 34}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{31, 32}, {34, 35}})); + Nd4j.create(new double[][] {{31, 32}, {34, 35}})); //Example 2 //depth 0 expected.put(new INDArrayIndex[] {point(2), point(0), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{36, 37}, {39, 40}})); + Nd4j.create(new double[][] {{36, 37}, {39, 40}})); expected.put(new INDArrayIndex[] {point(2), point(0), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{37, 38}, {40, 41}})); + Nd4j.create(new double[][] {{37, 38}, {40, 41}})); expected.put(new INDArrayIndex[] {point(2), point(0), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{39, 40}, {42, 43}})); + Nd4j.create(new double[][] {{39, 40}, {42, 43}})); expected.put(new INDArrayIndex[] {point(2), point(0), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{40, 41}, {43, 44}})); + Nd4j.create(new double[][] {{40, 41}, {43, 44}})); //depth 1 expected.put(new INDArrayIndex[] {point(2), point(1), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{45, 46}, {48, 49}})); + Nd4j.create(new double[][] {{45, 46}, {48, 49}})); expected.put(new INDArrayIndex[] {point(2), point(1), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{46, 47}, {49, 50}})); + Nd4j.create(new double[][] {{46, 47}, {49, 50}})); expected.put(new INDArrayIndex[] {point(2), point(1), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{48, 49}, {51, 52}})); + Nd4j.create(new double[][] {{48, 49}, {51, 52}})); expected.put(new INDArrayIndex[] {point(2), point(1), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{49, 50}, {52, 53}})); + Nd4j.create(new double[][] {{49, 50}, {52, 53}})); INDArray out = Convolution.im2col(input, kH, kW, sY, sX, pY, pX, false); assertEquals(expected, out); @@ -577,7 +581,9 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testIm2ColSamePadding() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIm2ColSamePadding(Nd4jBackend backend) { //Input: w=3, h=3, depth=2, minibatch = 2, kH/kW = 2, stride=1 //Idea with same padding: @@ -659,13 +665,13 @@ public class ConvolutionTests extends BaseNd4jTest { //Input data: shape [miniBatch,depth,height,width] INDArray input = Nd4j.create(new int[] {miniBatch, depth, inH, inW}, 'c'); input.put(new INDArrayIndex[] {point(0), point(0), all(), - all()}, Nd4j.create(new double[][] {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}})); + all()}, Nd4j.create(new double[][] {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}})); input.put(new INDArrayIndex[] {point(0), point(1), all(), - all()}, Nd4j.create(new double[][] {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}})); + all()}, Nd4j.create(new double[][] {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}})); input.put(new INDArrayIndex[] {point(1), point(0), all(), - all()}, Nd4j.create(new double[][] {{18, 19, 20}, {21, 22, 23}, {24, 25, 26}})); + all()}, Nd4j.create(new double[][] {{18, 19, 20}, {21, 22, 23}, {24, 25, 26}})); input.put(new INDArrayIndex[] {point(1), point(1), all(), - all()}, Nd4j.create(new double[][] {{27, 28, 29}, {30, 31, 32}, {33, 34, 35}})); + all()}, Nd4j.create(new double[][] {{27, 28, 29}, {30, 31, 32}, {33, 34, 35}})); //Expected data: INDArray expected = Nd4j.create(new int[] {miniBatch, depth, kH, kW, outH, outW}, 'c'); @@ -674,118 +680,118 @@ public class ConvolutionTests extends BaseNd4jTest { //depth 0 expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{0, 1}, {3, 4}})); + Nd4j.create(new double[][] {{0, 1}, {3, 4}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{1, 2}, {4, 5}})); + Nd4j.create(new double[][] {{1, 2}, {4, 5}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(0), point(2)}, - Nd4j.create(new double[][] {{2, 0}, {5, 0}})); + Nd4j.create(new double[][] {{2, 0}, {5, 0}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{3, 4}, {6, 7}})); + Nd4j.create(new double[][] {{3, 4}, {6, 7}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{4, 5}, {7, 8}})); + Nd4j.create(new double[][] {{4, 5}, {7, 8}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(1), point(2)}, - Nd4j.create(new double[][] {{5, 0}, {8, 0}})); + Nd4j.create(new double[][] {{5, 0}, {8, 0}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(2), point(0)}, - Nd4j.create(new double[][] {{6, 7}, {0, 0}})); + Nd4j.create(new double[][] {{6, 7}, {0, 0}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(2), point(1)}, - Nd4j.create(new double[][] {{7, 8}, {0, 0}})); + Nd4j.create(new double[][] {{7, 8}, {0, 0}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(2), point(2)}, - Nd4j.create(new double[][] {{8, 0}, {0, 0}})); + Nd4j.create(new double[][] {{8, 0}, {0, 0}})); //depth 1 expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{9, 10}, {12, 13}})); + Nd4j.create(new double[][] {{9, 10}, {12, 13}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{10, 11}, {13, 14}})); + Nd4j.create(new double[][] {{10, 11}, {13, 14}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(0), point(2)}, - Nd4j.create(new double[][] {{11, 0}, {14, 0}})); + Nd4j.create(new double[][] {{11, 0}, {14, 0}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{12, 13}, {15, 16}})); + Nd4j.create(new double[][] {{12, 13}, {15, 16}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{13, 14}, {16, 17}})); + Nd4j.create(new double[][] {{13, 14}, {16, 17}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(1), point(2)}, - Nd4j.create(new double[][] {{14, 0}, {17, 0}})); + Nd4j.create(new double[][] {{14, 0}, {17, 0}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(2), point(0)}, - Nd4j.create(new double[][] {{15, 16}, {0, 0}})); + Nd4j.create(new double[][] {{15, 16}, {0, 0}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(2), point(1)}, - Nd4j.create(new double[][] {{16, 17}, {0, 0}})); + Nd4j.create(new double[][] {{16, 17}, {0, 0}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(2), point(2)}, - Nd4j.create(new double[][] {{17, 0}, {0, 0}})); + Nd4j.create(new double[][] {{17, 0}, {0, 0}})); //Example 1 //depth 0 expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{18, 19}, {21, 22}})); + Nd4j.create(new double[][] {{18, 19}, {21, 22}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{19, 20}, {22, 23}})); + Nd4j.create(new double[][] {{19, 20}, {22, 23}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(0), point(2)}, - Nd4j.create(new double[][] {{20, 0}, {23, 0}})); + Nd4j.create(new double[][] {{20, 0}, {23, 0}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{21, 22}, {24, 25}})); + Nd4j.create(new double[][] {{21, 22}, {24, 25}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{22, 23}, {25, 26}})); + Nd4j.create(new double[][] {{22, 23}, {25, 26}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(1), point(2)}, - Nd4j.create(new double[][] {{23, 0}, {26, 0}})); + Nd4j.create(new double[][] {{23, 0}, {26, 0}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(2), point(0)}, - Nd4j.create(new double[][] {{24, 25}, {0, 0}})); + Nd4j.create(new double[][] {{24, 25}, {0, 0}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(2), point(1)}, - Nd4j.create(new double[][] {{25, 26}, {0, 0}})); + Nd4j.create(new double[][] {{25, 26}, {0, 0}})); expected.put(new INDArrayIndex[] {point(1), point(0), all(), all(), point(2), point(2)}, - Nd4j.create(new double[][] {{26, 0}, {0, 0}})); + Nd4j.create(new double[][] {{26, 0}, {0, 0}})); //depth 1 expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{27, 28}, {30, 31}})); + Nd4j.create(new double[][] {{27, 28}, {30, 31}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{28, 29}, {31, 32}})); + Nd4j.create(new double[][] {{28, 29}, {31, 32}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(0), point(2)}, - Nd4j.create(new double[][] {{29, 0}, {32, 0}})); + Nd4j.create(new double[][] {{29, 0}, {32, 0}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{30, 31}, {33, 34}})); + Nd4j.create(new double[][] {{30, 31}, {33, 34}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{31, 32}, {34, 35}})); + Nd4j.create(new double[][] {{31, 32}, {34, 35}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(1), point(2)}, - Nd4j.create(new double[][] {{32, 0}, {35, 0}})); + Nd4j.create(new double[][] {{32, 0}, {35, 0}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(2), point(0)}, - Nd4j.create(new double[][] {{33, 34}, {0, 0}})); + Nd4j.create(new double[][] {{33, 34}, {0, 0}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(2), point(1)}, - Nd4j.create(new double[][] {{34, 35}, {0, 0}})); + Nd4j.create(new double[][] {{34, 35}, {0, 0}})); expected.put(new INDArrayIndex[] {point(1), point(1), all(), all(), point(2), point(2)}, - Nd4j.create(new double[][] {{35, 0}, {0, 0}})); + Nd4j.create(new double[][] {{35, 0}, {0, 0}})); //[miniBatch,depth,kH,kW,outH,outW] INDArray outAlloc = Nd4j.create(miniBatch, depth, kH, kW, outH, outW); @@ -836,7 +842,9 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testIm2ColSamePaddingStride2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIm2ColSamePaddingStride2(Nd4jBackend backend) { //Input: h=3, w=4, depth=2, minibatch = 1, kH/kW = 3, stride=2 //Idea with same padding: @@ -904,10 +912,10 @@ public class ConvolutionTests extends BaseNd4jTest { //Input data: shape [miniBatch,depth,height,width] INDArray input = Nd4j.create(new int[] {miniBatch, depth, inH, inW}, 'c'); input.put(new INDArrayIndex[] {point(0), point(0), all(), - all()}, Nd4j.create(new double[][] {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})); + all()}, Nd4j.create(new double[][] {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})); input.put(new INDArrayIndex[] {point(0), point(1), all(), all()}, - Nd4j.create(new double[][] {{12, 13, 14, 15}, {16, 17, 18, 19}, {20, 21, 22, 23}})); + Nd4j.create(new double[][] {{12, 13, 14, 15}, {16, 17, 18, 19}, {20, 21, 22, 23}})); //Expected data: INDArray expected = Nd4j.create(new int[] {miniBatch, depth, kH, kW, outH, outW}, 'c'); @@ -916,29 +924,29 @@ public class ConvolutionTests extends BaseNd4jTest { //depth 0 expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{0, 0, 0}, {0, 1, 2}, {4, 5, 6}})); + Nd4j.create(new double[][] {{0, 0, 0}, {0, 1, 2}, {4, 5, 6}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{0, 0, 0}, {2, 3, 0}, {6, 7, 0}})); + Nd4j.create(new double[][] {{0, 0, 0}, {2, 3, 0}, {6, 7, 0}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{4, 5, 6}, {8, 9, 10}, {0, 0, 0}})); + Nd4j.create(new double[][] {{4, 5, 6}, {8, 9, 10}, {0, 0, 0}})); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{6, 7, 0}, {10, 11, 0}, {0, 0, 0}})); + Nd4j.create(new double[][] {{6, 7, 0}, {10, 11, 0}, {0, 0, 0}})); //depth 1 expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{0, 0, 0}, {12, 13, 14}, {16, 17, 18}})); + Nd4j.create(new double[][] {{0, 0, 0}, {12, 13, 14}, {16, 17, 18}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{0, 0, 0}, {14, 15, 0}, {18, 19, 0}})); + Nd4j.create(new double[][] {{0, 0, 0}, {14, 15, 0}, {18, 19, 0}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{16, 17, 18}, {20, 21, 22}, {0, 0, 0}})); + Nd4j.create(new double[][] {{16, 17, 18}, {20, 21, 22}, {0, 0, 0}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{18, 19, 0}, {22, 23, 0}, {0, 0, 0}})); + Nd4j.create(new double[][] {{18, 19, 0}, {22, 23, 0}, {0, 0, 0}})); //[miniBatch,depth,kH,kW,outH,outW] INDArray outAlloc = Nd4j.create(miniBatch, depth, kH, kW, outH, outW); @@ -989,7 +997,9 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testCol2ImSamePaddingStride2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCol2ImSamePaddingStride2(Nd4jBackend backend) { //Input: h=3, w=4, depth=2, minibatch = 1, kH/kW = 3, stride=2 //Idea with same padding: @@ -1075,39 +1085,39 @@ public class ConvolutionTests extends BaseNd4jTest { //depth 0 col6d.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{0, 0, 0}, {0, 1, 2}, {4, 5, 6}})); + Nd4j.create(new double[][] {{0, 0, 0}, {0, 1, 2}, {4, 5, 6}})); col6d.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{0, 0, 0}, {2, 3, 0}, {6, 7, 0}})); + Nd4j.create(new double[][] {{0, 0, 0}, {2, 3, 0}, {6, 7, 0}})); col6d.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{4, 5, 6}, {8, 9, 10}, {0, 0, 0}})); + Nd4j.create(new double[][] {{4, 5, 6}, {8, 9, 10}, {0, 0, 0}})); col6d.put(new INDArrayIndex[] {point(0), point(0), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{6, 7, 0}, {10, 11, 0}, {0, 0, 0}})); + Nd4j.create(new double[][] {{6, 7, 0}, {10, 11, 0}, {0, 0, 0}})); //depth 1 col6d.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(0), point(0)}, - Nd4j.create(new double[][] {{0, 0, 0}, {12, 13, 14}, {16, 17, 18}})); + Nd4j.create(new double[][] {{0, 0, 0}, {12, 13, 14}, {16, 17, 18}})); col6d.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(0), point(1)}, - Nd4j.create(new double[][] {{0, 0, 0}, {14, 15, 0}, {18, 19, 0}})); + Nd4j.create(new double[][] {{0, 0, 0}, {14, 15, 0}, {18, 19, 0}})); col6d.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(1), point(0)}, - Nd4j.create(new double[][] {{16, 17, 18}, {20, 21, 22}, {0, 0, 0}})); + Nd4j.create(new double[][] {{16, 17, 18}, {20, 21, 22}, {0, 0, 0}})); col6d.put(new INDArrayIndex[] {point(0), point(1), all(), all(), point(1), point(1)}, - Nd4j.create(new double[][] {{18, 19, 0}, {22, 23, 0}, {0, 0, 0}})); + Nd4j.create(new double[][] {{18, 19, 0}, {22, 23, 0}, {0, 0, 0}})); //Expected result: INDArray expected = Nd4j.create(miniBatch, depth, inH, inW); expected.put(new INDArrayIndex[] {point(0), point(0), all(), all()}, - Nd4j.create(new double[][] {{0, 1, 4, 3}, {8, 10, 24, 14}, {8, 9, 20, 11}})); + Nd4j.create(new double[][] {{0, 1, 4, 3}, {8, 10, 24, 14}, {8, 9, 20, 11}})); expected.put(new INDArrayIndex[] {point(0), point(1), all(), all()}, - Nd4j.create(new double[][] {{12, 13, 28, 15}, {32, 34, 72, 38}, {20, 21, 44, 23}})); + Nd4j.create(new double[][] {{12, 13, 28, 15}, {32, 34, 72, 38}, {20, 21, 44, 23}})); INDArray col2imResult = Nd4j.create(miniBatch, depth, inH, inW); @@ -1118,7 +1128,9 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testCol2ImSamePaddingStride1Dilation2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCol2ImSamePaddingStride1Dilation2(Nd4jBackend backend) { //Input: h=4, w=5, depth=1, minibatch = 1, kH/kW = 2, stride=1, dilation 2 //Idea with same padding: @@ -1305,13 +1317,17 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testConvOutWidthAndHeight() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConvOutWidthAndHeight(Nd4jBackend backend) { long outSize = Convolution.outSize(2, 1, 1, 2, 1, false); assertEquals(6, outSize); } /* - @Test - public void testIm2Col() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIm2Col(Nd4jBackend backend) { INDArray linspaced = Nd4j.linspace(1, 16, 16, DataType.FLOAT).reshape(2, 2, 2, 2); INDArray ret = Convolution.im2col(linspaced, 1, 1, 1, 1, 2, 2, 0, false); System.out.println(ret); @@ -1322,7 +1338,7 @@ public class ConvolutionTests extends BaseNd4jTest { @Test @Disabled - public void testCompareIm2ColImpl() { + public void testCompareIm2ColImpl(Nd4jBackend backend) { int[] miniBatches = {1, 3, 5}; int[] depths = {1, 3, 5}; @@ -1337,17 +1353,17 @@ public class ConvolutionTests extends BaseNd4jTest { boolean[] coverall = {false, true}; DataType[] types = new DataType[] {DataType.FLOAT, DataType.FLOAT, - DataType.FLOAT, DataType.FLOAT}; + DataType.FLOAT, DataType.FLOAT}; DataBuffer.AllocationMode[] modes = - new DataBuffer.AllocationMode[] {DataBuffer.AllocationMode.HEAP, DataBuffer.AllocationMode.HEAP, - DataBuffer.AllocationMode.DIRECT, DataBuffer.AllocationMode.DIRECT}; + new DataBuffer.AllocationMode[] {DataBuffer.AllocationMode.HEAP, DataBuffer.AllocationMode.HEAP, + DataBuffer.AllocationMode.DIRECT, DataBuffer.AllocationMode.DIRECT}; String factoryClassName = Nd4j.factory().getClass().toString().toLowerCase(); if (factoryClassName.contains("jcublas") || factoryClassName.contains("cuda")) { //Only test direct for CUDA; test all for CPU types = new DataType[] {DataType.FLOAT, DataType.FLOAT}; modes = new DataBuffer.AllocationMode[] {DataBuffer.AllocationMode.DIRECT, - DataBuffer.AllocationMode.DIRECT}; + DataBuffer.AllocationMode.DIRECT}; } DataType initialType = Nd4j.dataType(); @@ -1381,12 +1397,12 @@ public class ConvolutionTests extends BaseNd4jTest { //assertEquals(in.data().dataType(), opType); INDArray outOrig = OldConvolution.im2col(in, kh, kw, sh, sw, ph, - pw, -1, cAll); //Old implementation + pw, -1, cAll); //Old implementation INDArray outNew = Convolution.im2col(in, kh, kw, sh, sw, ph, pw, - cAll); //Current implementation + cAll); //Current implementation assertArrayEquals(outOrig.data().asFloat(), - outNew.data().asFloat(), 0.01f); + outNew.data().asFloat(), 0.01f); assertEquals(outOrig, outNew); } } @@ -1406,7 +1422,7 @@ public class ConvolutionTests extends BaseNd4jTest { @Test @Disabled - public void testCompareIm2Col() { + public void testCompareIm2Col(Nd4jBackend backend) { int[] miniBatches = {1, 3, 5}; int[] depths = {1, 3, 5}; @@ -1420,17 +1436,17 @@ public class ConvolutionTests extends BaseNd4jTest { int[] padW = {0, 1, 2}; DataType[] types = new DataType[] {DataType.FLOAT, DataType.FLOAT, - DataType.FLOAT, DataType.FLOAT}; + DataType.FLOAT, DataType.FLOAT}; DataBuffer.AllocationMode[] modes = - new DataBuffer.AllocationMode[] {DataBuffer.AllocationMode.HEAP, DataBuffer.AllocationMode.HEAP, - DataBuffer.AllocationMode.DIRECT, DataBuffer.AllocationMode.DIRECT}; + new DataBuffer.AllocationMode[] {DataBuffer.AllocationMode.HEAP, DataBuffer.AllocationMode.HEAP, + DataBuffer.AllocationMode.DIRECT, DataBuffer.AllocationMode.DIRECT}; String factoryClassName = Nd4j.factory().getClass().toString().toLowerCase(); if (factoryClassName.contains("jcublas") || factoryClassName.contains("cuda")) { //Only test direct for CUDA; test all for CPU types = new DataType[] {DataType.FLOAT, DataType.FLOAT}; modes = new DataBuffer.AllocationMode[] {DataBuffer.AllocationMode.DIRECT, - DataBuffer.AllocationMode.DIRECT}; + DataBuffer.AllocationMode.DIRECT}; } DataType inititalType = Nd4j.dataType(); @@ -1459,12 +1475,12 @@ public class ConvolutionTests extends BaseNd4jTest { assertEquals(in.data().allocationMode(), mode); assertEquals(in.data().dataType(), type); INDArray im2col = Convolution.im2col(in, kh, kw, sh, sw, ph, pw, - false); //Cheating, to get correct shape for input + false); //Cheating, to get correct shape for input INDArray imgOutOld = - OldConvolution.col2im(im2col, sh, sw, ph, pw, h, w); + OldConvolution.col2im(im2col, sh, sw, ph, pw, h, w); INDArray imgOutNew = - Convolution.col2im(im2col, sh, sw, ph, pw, h, w); + Convolution.col2im(im2col, sh, sw, ph, pw, h, w); System.out.println("F order test"); System.out.println(imgOutOld); System.out.println(imgOutNew); @@ -1486,7 +1502,9 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testCol2Im() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCol2Im(Nd4jBackend backend) { int kh = 1; int kw = 1; int sy = 1; @@ -1505,7 +1523,9 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testimcolim() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testimcolim(Nd4jBackend backend) { int nEx = 2; int depth = 3; int width = 7; @@ -1527,7 +1547,9 @@ public class ConvolutionTests extends BaseNd4jTest { } @Test - public void testIm2ColWithDilation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIm2ColWithDilation(Nd4jBackend backend) { int kH = 2; int kW = 2; int sH = 1; @@ -1571,6 +1593,8 @@ public class ConvolutionTests extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPoolingEdgeCases(){ //Average pooling with same mode: should we include the padded values, when deciding what to divide by? ///*** Note: Mode 2 is the "DL4J always divide by kH*kW" approach *** @@ -1655,7 +1679,9 @@ public class ConvolutionTests extends BaseNd4jTest { } @Test - public void testPooling1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling1(Nd4jBackend backend) { for( char outputOrder : new char[]{'c', 'f'}) { INDArray exp = Nd4j.create(new float[]{6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}, new int[]{2, 2, 2, 2}, 'c'); @@ -1717,7 +1743,9 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testPooling2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling2(Nd4jBackend backend) { for( char outputOrder : new char[]{'c', 'f'}) { INDArray exp = Nd4j.create(new float[]{6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}, new int[]{2, 2, 2, 2}, 'c'); @@ -1739,7 +1767,9 @@ public class ConvolutionTests extends BaseNd4jTest { } @Test - public void testPooling3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling3(Nd4jBackend backend) { for( char outputOrder : new char[]{'c', 'f'}) { INDArray exp = Nd4j.create(new float[]{11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, 59.f, 60.f, 63.f, 64.f}, new int[]{2, 2, 2, 2}, 'c'); @@ -1762,7 +1792,9 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testPooling4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling4(Nd4jBackend backend) { for( char outputOrder : new char[]{'c', 'f'}) { INDArray exp = Nd4j.create(new float[]{11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, 59.f, 60.f, 63.f, 64.f}, new int[]{2, 2, 2, 2}, 'c'); @@ -1785,7 +1817,9 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testPooling5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling5(Nd4jBackend backend) { for( char outputOrder : new char[]{'c', 'f'}) { INDArray exp = Nd4j.create(new float[]{7.f, 8.f, 11.f, 12.f, 14.f, 15.f, 27.f, 28.f, 31.f, 32.f, 34.f, 35.f, 42.f, 43.f, 46.f, 47.f, 49.f, 50.f, 57.f, 58.f, 61.f, 62.f, 64.f, 65.f, 77.f, 78.f, 81.f, 82.f, 84.f, 85.f, 92.f, 93.f, 96.f, 97.f, 99.f, 100.f}, new int[]{2, 3, 3, 2}, 'c'); @@ -1808,7 +1842,9 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testPooling6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling6(Nd4jBackend backend) { for( char outputOrder : new char[]{'c', 'f'}) { INDArray exp = Nd4j.create(new float[]{7.f, 8.f, 11.f, 12.f, 27.f, 28.f, 31.f, 32.f, 57.f, 58.f, 61.f, 62.f, 77.f, 78.f, 81.f, 82.f}, new int[]{2, 2, 2, 2}, 'c'); @@ -1831,7 +1867,9 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testPooling7() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling7(Nd4jBackend backend) { for( char outputOrder : new char[]{'c', 'f'}) { INDArray exp = Nd4j.create(new float[]{7.f, 9.f, 17.f, 19.f, 32.f, 34.f, 42.f, 44.f, 57.f, 59.f, 67.f, 69.f, 82.f, 84.f, 92.f, 94.f}, new int[]{2, 2, 2, 2}, 'c'); @@ -1853,7 +1891,9 @@ public class ConvolutionTests extends BaseNd4jTest { } @Test - public void testPooling8() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling8(Nd4jBackend backend) { for( char outputOrder : new char[]{'c', 'f'}) { INDArray exp = Nd4j.create(new float[]{1.f, 2.5f, 4.5f, 8.5f, 10.f, 12.f, 18.5f, 20.f, 22.f, 26.f, 27.5f, 29.5f, 33.5f, 35.f, 37.f, 43.5f, 45.f, 47.f, 51.f, 52.5f, 54.5f, 58.5f, 60.f, 62.f, 68.5f, 70.f, 72.f, 76.f, 77.5f, 79.5f, 83.5f, 85.f, 87.f, 93.5f, 95.f, 97.f}, new int[]{2, 2, 3, 3}, 'c'); @@ -1875,7 +1915,9 @@ public class ConvolutionTests extends BaseNd4jTest { } @Test - public void testPooling9() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling9(Nd4jBackend backend) { for( char outputOrder : new char[]{'c', 'f'}) { INDArray exp = Nd4j.create(new float[]{0.25f, 1.25f, 2.25f, 4.25f, 10.f, 12.f, 9.25f, 20.f, 22.f, 6.5f, 13.75f, 14.75f, 16.75f, 35.f, 37.f, 21.75f, 45.f, 47.f, 12.75f, 26.25f, 27.25f, 29.25f, 60.f, 62.f, 34.25f, 70.f, 72.f, 19.f, 38.75f, 39.75f, 41.75f, 85.f, 87.f, 46.75f, 95.f, 97.f}, new int[]{2, 2, 3, 3}, 'c'); @@ -1897,7 +1939,9 @@ public class ConvolutionTests extends BaseNd4jTest { } @Test - public void testPooling10() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling10(Nd4jBackend backend) { for( char outputOrder : new char[]{'c', 'f'}) { INDArray exp = Nd4j.create(new float[]{4.f, 6.f, 7.5f, 14.f, 16.f, 17.5f, 21.5f, 23.5f, 25.f, 29.f, 31.f, 32.5f, 39.f, 41.f, 42.5f, 46.5f, 48.5f, 50.f, 54.f, 56.f, 57.5f, 64.f, 66.f, 67.5f, 71.5f, 73.5f, 75.f, 79.f, 81.f, 82.5f, 89.f, 91.f, 92.5f, 96.5f, 98.5f, 100.f}, new int[]{2, 2, 3, 3}, 'c'); @@ -1919,7 +1963,9 @@ public class ConvolutionTests extends BaseNd4jTest { } @Test - public void testPooling11() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling11(Nd4jBackend backend) { for( char outputOrder : new char[]{'c', 'f'}) { INDArray exp = Nd4j.create(new float[]{3, 4, 6, 7}, new int[]{1, 1, 2, 2}, 'c'); @@ -1941,7 +1987,9 @@ public class ConvolutionTests extends BaseNd4jTest { } @Test - public void testPooling12() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling12(Nd4jBackend backend) { for( char outputOrder : new char[]{'c', 'f'}) { INDArray exp = Nd4j.create(new float[]{3.f, 4.f, 4.5f, 6.f, 7.f, 7.5f, 7.5f, 8.5f, 9.f}, new int[]{1, 1, 3, 3}, 'c'); @@ -1964,7 +2012,9 @@ public class ConvolutionTests extends BaseNd4jTest { @Test - public void testPooling13() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling13(Nd4jBackend backend) { for( char outputOrder : new char[]{'c'}) { INDArray exp = Nd4j.create(new float[]{3.f, 4.f, 4.5f, 6.f, 7.f, 7.5f, 7.5f, 8.5f, 9.f}, new int[]{1, 1, 3, 3}, 'c'); @@ -1988,6 +2038,8 @@ public class ConvolutionTests extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPoolingDilation(){ int[] inputShape = {1, 1, 4, 5}; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTestsC.java index f48acb810..4278849e4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/ConvolutionTestsC.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.convolution; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.AllocUtil; @@ -46,22 +47,23 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j -@RunWith(Parameterized.class) -public class ConvolutionTestsC extends BaseNd4jTest { - public ConvolutionTestsC(Nd4jBackend backend) { - super(backend); - } +public class ConvolutionTestsC extends BaseNd4jTestWithBackends { + @Test - public void testConvOutWidthAndHeight() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConvOutWidthAndHeight(Nd4jBackend backend) { long outSize = Convolution.outSize(2, 1, 1, 2, 1, false); assertEquals(6, outSize); } @Test - public void testIm2Col() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIm2Col(Nd4jBackend backend) { INDArray linspaced = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); INDArray ret = Convolution.im2col(linspaced, 1, 1, 1, 1, 2, 2, 0, false); INDArray im2colAssertion = Nd4j.create(new double[] {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, @@ -85,7 +87,9 @@ public class ConvolutionTestsC extends BaseNd4jTest { } @Test - public void testIm2Col2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIm2Col2(Nd4jBackend backend) { int kh = 2; int kw = 2; int ph = 0; @@ -107,7 +111,9 @@ public class ConvolutionTestsC extends BaseNd4jTest { @Test @Disabled - public void testCompareIm2ColImpl() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCompareIm2ColImpl(Nd4jBackend backend) { int[] miniBatches = {1, 3, 5}; int[] depths = {1, 3, 5}; @@ -188,7 +194,9 @@ public class ConvolutionTestsC extends BaseNd4jTest { } @Test - public void testPooling2D_Same() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPooling2D_Same(Nd4jBackend backend) { int[] miniBatches = {1, 3, 5}; int[] depths = {1, 3, 5}; int[] inHeights = {5, 21}; @@ -249,7 +257,7 @@ public class ConvolutionTestsC extends BaseNd4jTest { Convolution.pooling2D(in, kh, kw, sh, sw, padTop, padLeft, 1, 1, true, Pooling2D.Pooling2DType.PNORM, Pooling2D.Divisor.INCLUDE_PADDING, - (double) pnorm, outSize[0], outSize[1], output); + pnorm, outSize[0], outSize[1], output); break; case MAX: @@ -284,7 +292,9 @@ public class ConvolutionTestsC extends BaseNd4jTest { } @Test - public void testMoreIm2Col2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMoreIm2Col2(Nd4jBackend backend) { int kh = 2; int kw = 2; int ph = 0; @@ -306,7 +316,9 @@ public class ConvolutionTestsC extends BaseNd4jTest { @Test - public void testCol2Im() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCol2Im(Nd4jBackend backend) { int kh = 1; int kw = 1; int sy = 1; @@ -322,7 +334,9 @@ public class ConvolutionTestsC extends BaseNd4jTest { } @Test - public void testimcolim() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testimcolim(Nd4jBackend backend) { int nEx = 2; int depth = 3; int width = 7; @@ -346,6 +360,8 @@ public class ConvolutionTestsC extends BaseNd4jTest { @Test @Disabled + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMaxPoolBackprop(){ Nd4j.getRandom().setSeed(12345); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java index f88ee0cc1..8886d89de 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java @@ -27,9 +27,10 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.common.io.ClassPathResource; -import org.nd4j.common.resources.Resources; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; @@ -45,11 +46,8 @@ import java.util.HashSet; import java.util.List; import java.util.Set; -public class DeconvTests extends BaseNd4jTest { +public class DeconvTests extends BaseNd4jTestWithBackends { - public DeconvTests(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -57,7 +55,9 @@ public class DeconvTests extends BaseNd4jTest { } @Test - public void compareKeras(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void compareKeras(@TempDir Path testDir,Nd4jBackend backend) throws Exception { File newFolder = testDir.toFile(); new ClassPathResource("keras/deconv/").copyDirectory(newFolder); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java index 5736a5577..503f95fa2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java @@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.RandomUtils; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax; @@ -40,12 +41,9 @@ import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.conditions.Conditions; @Slf4j -@RunWith(Parameterized.class) + @Disabled -public class CrashTest extends BaseNd4jTest { - public CrashTest(Nd4jBackend backend) { - super(backend); - } +public class CrashTest extends BaseNd4jTestWithBackends { private static final int ITERATIONS = 10; private static final boolean[] paramsA = new boolean[] {true, false}; @@ -56,7 +54,9 @@ public class CrashTest extends BaseNd4jTest { * tensorAlongDimension() produces shapeInfo without EWS defined */ @Test - public void testNonEWSViews1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNonEWSViews1(Nd4jBackend backend) { log.debug("non-EWS 1"); INDArray x = Nd4j.create(64, 1024, 64); INDArray y = Nd4j.create(64, 64, 1024); @@ -68,7 +68,9 @@ public class CrashTest extends BaseNd4jTest { } @Test - public void testNonEWSViews2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNonEWSViews2(Nd4jBackend backend) { log.debug("non-EWS 2"); INDArray x = Nd4j.create(new int[] {64, 1024, 64}, 'f'); INDArray y = Nd4j.create(new int[] {64, 64, 1024}, 'f'); @@ -83,7 +85,9 @@ public class CrashTest extends BaseNd4jTest { * slice() produces shapeInfo with EWS being 1 in our case */ @Test - public void testEWSViews1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEWSViews1(Nd4jBackend backend) { log.debug("EWS 1"); INDArray x = Nd4j.create(64, 1024, 64); INDArray y = Nd4j.create(64, 64, 1024); @@ -95,7 +99,9 @@ public class CrashTest extends BaseNd4jTest { } @Test - public void testEWSViews2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEWSViews2(Nd4jBackend backend) { log.debug("EWS 2"); INDArray x = Nd4j.create(new int[] {96, 1024, 64}, 'f'); INDArray y = Nd4j.create(new int[] {96, 64, 1024}, 'f'); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java index 4e0c28c89..ce09ff895 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java @@ -25,9 +25,10 @@ import lombok.val; import lombok.var; import org.apache.commons.lang3.RandomUtils; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.enums.AllocationPolicy; @@ -55,15 +56,14 @@ import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.*; @Slf4j -@RunWith(Parameterized.class) -public class SpecialTests extends BaseNd4jTest { - public SpecialTests(Nd4jBackend backend) { - super(backend); - } + +public class SpecialTests extends BaseNd4jTestWithBackends { @Test - public void testDimensionalThings1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDimensionalThings1(Nd4jBackend backend) { INDArray x = Nd4j.rand(new int[] {20, 30, 50}); INDArray y = Nd4j.rand(x.shape()); @@ -71,7 +71,9 @@ public class SpecialTests extends BaseNd4jTest { } @Test - public void testDimensionalThings2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDimensionalThings2(Nd4jBackend backend) { INDArray x = Nd4j.rand(new int[] {20, 30, 50}); INDArray y = Nd4j.rand(x.shape()); @@ -100,7 +102,7 @@ public class SpecialTests extends BaseNd4jTest { @Test() - public void testScalarShuffle1() { + public void testScalarShuffle1(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { List listData = new ArrayList<>(); for (int i = 0; i < 3; i++) { @@ -117,7 +119,9 @@ public class SpecialTests extends BaseNd4jTest { @Test - public void testScalarShuffle2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarShuffle2(Nd4jBackend backend) { List listData = new ArrayList<>(); for (int i = 0; i < 3; i++) { INDArray features = Nd4j.ones(14, 25); @@ -130,7 +134,9 @@ public class SpecialTests extends BaseNd4jTest { } @Test - public void testVstack2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVstack2(Nd4jBackend backend) { INDArray matrix = Nd4j.create(10000, 100); List views = new ArrayList<>(); @@ -142,7 +148,9 @@ public class SpecialTests extends BaseNd4jTest { } @Test - public void testVstack1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVstack1(Nd4jBackend backend) { INDArray matrix = Nd4j.create(10000, 100); List views = new ArrayList<>(); @@ -162,6 +170,8 @@ public class SpecialTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testConcatMulti() throws Exception { val shapeA = new int[] {50, 20}; val shapeB = new int[] {50, 497}; @@ -171,11 +181,8 @@ public class SpecialTests extends BaseNd4jTest { val executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(2); for (int e = 0; e < 1; e++) { - executor.submit(new Runnable() { - @Override - public void run() { - val arrayA = Nd4j.createUninitialized(shapeA); - } + executor.submit(() -> { + val arrayA = Nd4j.createUninitialized(shapeA); }); } @@ -183,18 +190,19 @@ public class SpecialTests extends BaseNd4jTest { } @Test - public void testConcatMulti2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatMulti2(Nd4jBackend backend) { Nd4j.create(1); val executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(2); - executor.submit(new Runnable() { - @Override - public void run() { + executor.submit(() -> { // System.out.println("A"); - } }); } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMigrationMultiGpu_1() throws Exception { if (Nd4j.getAffinityManager().getNumberOfDevices() < 2) return; @@ -204,18 +212,15 @@ public class SpecialTests extends BaseNd4jTest { val devices = Nd4j.getAffinityManager().getNumberOfDevices(); for (int e = 0; e < devices; e++) { val f = e; - val t = new Thread(new Runnable() { - @Override - public void run() { - val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); - log.info("Current device: {}", deviceId); - for (int i = 0; i < 10; i++) { - val ar = Nd4j.create(100, 100).assign(1.0f); + val t = new Thread(() -> { + val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + log.info("Current device: {}", deviceId); + for (int i = 0; i < 10; i++) { + val ar = Nd4j.create(100, 100).assign(1.0f); - assertEquals(deviceId, Nd4j.getAffinityManager().getDeviceForArray(ar)); - list.add(ar); - Nd4j.getExecutioner().commit(); - } + assertEquals(deviceId, Nd4j.getAffinityManager().getDeviceForArray(ar)); + list.add(ar); + Nd4j.getExecutioner().commit(); } }); @@ -241,6 +246,8 @@ public class SpecialTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMigrationMultiGpu_2() throws Exception { if (Nd4j.getAffinityManager().getNumberOfDevices() < 2) return; @@ -257,14 +264,11 @@ public class SpecialTests extends BaseNd4jTest { val threads = new ArrayList(); for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { val f = e; - val t = new Thread(new Runnable() { - @Override - public void run() { - for (int i = 0; i < 100; i++) { - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsConf, "id")) { - list.add(Nd4j.create(3, 3).assign(1.0f)); - Nd4j.getExecutioner().commit(); - } + val t = new Thread(() -> { + for (int i = 0; i < 100; i++) { + try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsConf, "id")) { + list.add(Nd4j.create(3, 3).assign(1.0f)); + Nd4j.getExecutioner().commit(); } } }); @@ -286,6 +290,8 @@ public class SpecialTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBroadcastLt(){ for( int i=0; i<10; i++) { @@ -298,6 +304,8 @@ public class SpecialTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBroadcastLt2(){ for( int i=0; i<10; i++) { INDArray orig = Nd4j.create(DataType.DOUBLE, 1, 7, 4, 4); @@ -311,6 +319,8 @@ public class SpecialTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void reproduceWorkspaceCrash(){ val conf = WorkspaceConfiguration.builder().build(); @@ -336,6 +346,8 @@ public class SpecialTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void reproduceWorkspaceCrash_2(){ val dtypes = new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE, DataType.UBYTE, DataType.BOOL}; for (val dX : dtypes) { @@ -352,6 +364,8 @@ public class SpecialTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void reproduceWorkspaceCrash_3(){ val conf = WorkspaceConfiguration.builder().build(); @@ -373,7 +387,9 @@ public class SpecialTests extends BaseNd4jTest { } @Test - public void testCastLong_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCastLong_1(Nd4jBackend backend) { val array = Nd4j.create(DataType.LONG, 100, 100).assign(1); val second = Nd4j.create(DataType.LONG, 100, 100).assign(1); // log.info("----------------"); @@ -386,51 +402,67 @@ public class SpecialTests extends BaseNd4jTest { } @Test - public void testCastHalf_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCastHalf_1(Nd4jBackend backend) { val array = Nd4j.create(DataType.HALF, 2, 5).assign(1); assertEquals(10.f, array.sumNumber().floatValue(), 1e-3); } @Test - public void testCastHalf_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCastHalf_2(Nd4jBackend backend) { val array = Nd4j.create(DataType.HALF, 2, 5).assign(1); assertEquals(10.f, array.sumNumber().floatValue(), 1e-3); } @Test - public void testCastHalf_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCastHalf_3(Nd4jBackend backend) { val arrayY = Nd4j.create(DataType.FLOAT, 2, 5).assign(2); val arrayX = Nd4j.create(DataType.HALF, 2, 5).assign(arrayY); assertEquals(20.f, arrayX.sumNumber().floatValue(), 1e-3); } @Test - public void testReduce_Small_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduce_Small_1(Nd4jBackend backend) { val array = Nd4j.create(DataType.SHORT, 100, 30).assign(1); assertEquals(3000, array.sumNumber().intValue()); } @Test - public void testReduce_Small_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduce_Small_2(Nd4jBackend backend) { val array = Nd4j.create(DataType.BYTE, 100, 100).assign(0); assertEquals(0, array.sumNumber().intValue()); } @Test - public void testReduce3_Small_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduce3_Small_1(Nd4jBackend backend) { val arrayA = Nd4j.create(DataType.SHORT, 100, 100).assign(1); val arrayB = Nd4j.create(DataType.SHORT, 100, 100).assign(1); assertEquals(arrayA, arrayB); } @Test - public void testReduce3_Small_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduce3_Small_2(Nd4jBackend backend) { val arrayA = Nd4j.create(DataType.BYTE, 100, 100).assign(1); val arrayB = Nd4j.create(DataType.BYTE, 100, 100).assign(1); assertEquals(arrayA, arrayB); } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void reproduceWorkspaceCrash_4(){ val conf = WorkspaceConfiguration.builder().build(); @@ -452,6 +484,8 @@ public class SpecialTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void reproduceWorkspaceCrash_5(){ val conf = WorkspaceConfiguration.builder().build(); @@ -471,6 +505,8 @@ public class SpecialTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testConcatAgain(){ INDArray[] toConcat = new INDArray[3]; for( int i=0; i { val arrayX = Nd4j.create(10, 10); val arrayY = Nd4j.create(10, 10); @@ -214,8 +221,11 @@ public class CustomOpsTests extends BaseNd4jTest { } + @Test - public void testNoneInplaceOp3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNoneInplaceOp3(Nd4jBackend backend) { val arrayX = Nd4j.create(10, 10); val arrayY = Nd4j.create(10, 10); @@ -234,8 +244,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, op.getOutputArgument(0)); } + @Test - public void testNoneInplaceOp4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNoneInplaceOp4(Nd4jBackend backend) { val arrayX = Nd4j.create(DataType.INT, 10, 10); val arrayY = Nd4j.create(DataType.INT, 10, 10); @@ -256,8 +269,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, res); } + @Test - public void testNoneInplaceOp5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNoneInplaceOp5(Nd4jBackend backend) { if (!Nd4j.isExperimentalMode()) return; @@ -281,8 +297,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, res); } + @Test - public void testMergeMax1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergeMax1(Nd4jBackend backend) { val array0 = Nd4j.create(new double[] {1, 0, 0, 0, 0}); val array1 = Nd4j.create(new double[] {0, 2, 0, 0, 0}); val array2 = Nd4j.create(new double[] {0, 0, 3, 0, 0}); @@ -303,8 +322,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, z); } + @Test - public void testMergeMaxF() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergeMaxF(Nd4jBackend backend) { val array0 = Nd4j.rand('f', 5, 2).add(1); //some random array with +ve numbers val array1 = array0.dup('f').add(5); @@ -324,8 +346,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, zF); } + @Test - public void testMergeMaxMixedOrder_Subtract() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergeMaxMixedOrder_Subtract(Nd4jBackend backend) { val exp = Nd4j.create(new int[] {2, 2}, 'c').assign(5.0); Nd4j.getExecutioner().commit(); @@ -337,8 +362,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, array1); } + @Test - public void testMergeMaxSameOrder_Subtract() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergeMaxSameOrder_Subtract(Nd4jBackend backend) { val exp = Nd4j.create(new int[] {2, 2}, 'c').assign(5.0); Nd4j.getExecutioner().commit(); @@ -348,8 +376,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, array1); } + @Test - public void testMergeMaxMixedOrder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergeMaxMixedOrder(Nd4jBackend backend) { val array0 = Nd4j.rand('f', 5, 2).addi(1); //some random array with +ve numbers val array1 = array0.dup('c').addi(5); array1.put(0, 0, 0); //array1 is always bigger than array0 except at 0,0 @@ -370,8 +401,11 @@ public class CustomOpsTests extends BaseNd4jTest { } + @Test - public void testOutputShapes1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOutputShapes1(Nd4jBackend backend) { val array0 = Nd4j.rand('f', 5, 2).addi(1); //some random array with +ve numbers val array1 = array0.dup().addi(5); array1.put(0, 0, 0); //array1 is always bigger than array0 except at 0,0 @@ -392,13 +426,19 @@ public class CustomOpsTests extends BaseNd4jTest { + @Test - public void testOpStatus1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOpStatus1(Nd4jBackend backend) { assertEquals(OpStatus.ND4J_STATUS_OK, OpStatus.byNumber(0)); } + @Test - public void testRandomStandardNormal_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRandomStandardNormal_1(Nd4jBackend backend) { if (Nd4j.getExecutioner().type() == OpExecutioner.ExecutionerType.CUDA) return; @@ -413,8 +453,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(new long[]{5, 10}, output.shape()); } + @Test - public void testRandomStandardNormal_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRandomStandardNormal_2(Nd4jBackend backend) { if (Nd4j.getExecutioner().type() == OpExecutioner.ExecutionerType.CUDA) return; @@ -429,8 +472,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(new long[]{5, 10}, output.shape()); } + @Test - public void testOpContextExecution_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOpContextExecution_1(Nd4jBackend backend) { val arrayX = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5}); val arrayY = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5}); val arrayZ = Nd4j.create(DataType.FLOAT, 5); @@ -448,8 +494,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, arrayZ); } + @Test - public void testOpContextExecution_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOpContextExecution_2(Nd4jBackend backend) { val arrayX = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5}); val arrayY = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5}); val arrayZ = Nd4j.create(DataType.FLOAT, 5); @@ -468,8 +517,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertTrue(arrayZ == output[0]); } + @Test - public void testOpContextExecution_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOpContextExecution_3(Nd4jBackend backend) { val arrayX = Nd4j.create(100); val arrayY = Nd4j.ones(100); val arrayZ = Nd4j.create(100); @@ -489,8 +541,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertTrue(arrayZ == output[0]); } + @Test - public void testFlatten_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFlatten_1(Nd4jBackend backend) { val arrayA = Nd4j.createFromArray(1.f, 2.f, 3.f); val arrayB = Nd4j.createFromArray(4.f, 5.f, 6.f); val arrayC = Nd4j.createFromArray(7.f, 8.f, 9.f); @@ -502,8 +557,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, result); } + @Test - public void testMatmulBp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatmulBp(Nd4jBackend backend) { val a = Nd4j.create(DataType.DOUBLE, 1,3); val b = Nd4j.create(DataType.DOUBLE, 1,4); val gI = Nd4j.create(DataType.DOUBLE, 3,4); @@ -520,7 +578,10 @@ public class CustomOpsTests extends BaseNd4jTest { Nd4j.exec(op); } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testStridedSliceEdgeCase(){ INDArray in = Nd4j.scalar(10.0).reshape(1); //Int [1] INDArray begin = Nd4j.ones(DataType.INT, 1); @@ -547,7 +608,10 @@ public class CustomOpsTests extends BaseNd4jTest { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDepthwise(){ INDArray input = Nd4j.create(DataType.DOUBLE, 1,3,8,8); INDArray depthwiseWeight = Nd4j.create(DataType.DOUBLE, 1,1,3,2); @@ -572,8 +636,11 @@ public class CustomOpsTests extends BaseNd4jTest { } } + @Test - public void testMod_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMod_1(Nd4jBackend backend) { val x = Nd4j.createFromArray(5.f, 6.f, 7.f); val y = Nd4j.scalar(4.f); val e = Nd4j.createFromArray(1.f, 2.f, 3.f); @@ -583,8 +650,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(e, z); } + @Test - public void testScalarVector_edge_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarVector_edge_1(Nd4jBackend backend) { val x = Nd4j.scalar(2.0f); val y = Nd4j.createFromArray(new float[]{2.0f}); val e = Nd4j.createFromArray(new float[]{4.0f}); @@ -595,8 +665,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(e, z); } + @Test - public void testScalarVector_edge_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarVector_edge_2(Nd4jBackend backend) { val x = Nd4j.scalar(2.0f); val y = Nd4j.createFromArray(new float[]{2.0f}); val e = Nd4j.createFromArray(new float[]{4.0f}); @@ -627,7 +700,10 @@ public class CustomOpsTests extends BaseNd4jTest { } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testUpsampling2dBackprop(){ Nd4j.getRandom().setSeed(12345); @@ -671,7 +747,10 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, act); } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIsMaxView(){ INDArray predictions = Nd4j.rand(DataType.FLOAT, 3, 4, 3, 2); @@ -688,7 +767,10 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(result1, result2); } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void isMax4d_2dims(){ Nd4j.getRandom().setSeed(12345); INDArray in = Nd4j.rand(DataType.FLOAT, 3, 3, 4, 4).permute(0, 2, 3, 1); @@ -702,7 +784,10 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(out_dupedIn, out_permutedIn); } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSizeTypes(){ List failed = new ArrayList<>(); for(DataType dt : new DataType[]{DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE, @@ -732,7 +817,10 @@ public class CustomOpsTests extends BaseNd4jTest { } } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testListDiff(){ INDArray x = Nd4j.createFromArray(0, 1, 2, 3); INDArray y = Nd4j.createFromArray(3, 1); @@ -751,7 +839,10 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, outIdx); //Indices of the values in x not in y } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTopK1(){ INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0); INDArray k = Nd4j.scalar(1); @@ -772,8 +863,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expIdx, outIdx); } + @Test - public void testMaxPool2Dbp_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMaxPool2Dbp_1(Nd4jBackend backend) { val x = Nd4j.create(DataType.HALF, 2,3,16,16).assign(Double.NaN); val y = Nd4j.create(DataType.HALF, 2,3,8,8).assign(Double.NaN); val z = Nd4j.create(DataType.HALF, 2,3,16,16); @@ -788,7 +882,10 @@ public class CustomOpsTests extends BaseNd4jTest { Nd4j.getExecutioner().commit(); } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void test() throws Exception { INDArray in1 = Nd4j.create(DataType.BFLOAT16, 2, 3, 10, 1);//Nd4j.createFromArray(0.2019043,0.6464844,0.9116211,0.60058594,0.34033203,0.7036133,0.6772461,0.3815918,0.87353516,0.04650879,0.67822266,0.8618164,0.88378906,0.7573242,0.66796875,0.63427734,0.33764648,0.46923828,0.62939453,0.76464844,-0.8618164,-0.94873047,-0.9902344,-0.88916016,-0.86572266,-0.92089844,-0.90722656,-0.96533203,-0.97509766,-0.4975586,-0.84814453,-0.984375,-0.98828125,-0.95458984,-0.9472656,-0.91064453,-0.80859375,-0.83496094,-0.9140625,-0.82470703,0.4802246,0.45361328,0.28125,0.28320312,0.79345703,0.44604492,-0.30273438,0.11730957,0.56396484,0.73583984,0.1418457,-0.44848633,0.6923828,-0.40234375,0.40185547,0.48632812,0.14538574,0.4638672,0.13000488,0.5058594) @@ -807,8 +904,11 @@ public class CustomOpsTests extends BaseNd4jTest { Nd4j.getExecutioner().commit(); } + @Test - public void testAdjustContrast() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdjustContrast(Nd4jBackend backend) { INDArray in = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4*4*3).reshape(4,4,3); INDArray out = Nd4j.zeros(DataType.DOUBLE,4, 4, 3); @@ -823,7 +923,10 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, out); } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testAdjustContrastShape(){ DynamicCustomOp op = DynamicCustomOp.builder("adjust_contrast_v2") .addInputs(Nd4j.create(DataType.FLOAT, 256, 256,3), Nd4j.scalar(0.5f)) @@ -834,7 +937,10 @@ public class CustomOpsTests extends BaseNd4jTest { } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBitCastShape(){ INDArray out = Nd4j.createUninitialized(1,10); BitCast op = new BitCast(Nd4j.zeros(1,10), DataType.FLOAT.toInt(), out); @@ -843,8 +949,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(new long[]{1,10,2}, lsd.get(0).getShape()); } + @Test - public void testAdjustSaturation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdjustSaturation(Nd4jBackend backend) { INDArray in = Nd4j.createFromArray(new double[]{50,100,78, 118.5,220,112.5,190,163.5,230, 255,128.5,134}).reshape(2,2,3); INDArray out = Nd4j.create(in.shape()); INDArray expected = Nd4j.createFromArray(new double[]{0,100,56, 17,220,5, 150,97,230, 255,2,13}).reshape(2,2,3); @@ -853,8 +962,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, out); } + @Test - public void testAdjustHue() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdjustHue(Nd4jBackend backend) { INDArray in = Nd4j.createFromArray(new double[]{0,100,56, 17,220,5, 150,97,230, 255,2,13}).reshape(2,2,3); INDArray out = Nd4j.create(in.shape()); INDArray expected = Nd4j.createFromArray(new double[]{100,0,44, 208,5,220, 177,230,97, 2,255,244}).reshape(2,2,3); @@ -863,8 +975,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, out); } + @Test - public void testBitCast() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBitCast(Nd4jBackend backend) { INDArray in = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 8).reshape(2,2,2); INDArray out = Nd4j.createUninitialized(2,2); @@ -877,7 +992,9 @@ public class CustomOpsTests extends BaseNd4jTest { @Test @Disabled - public void testDrawBoundingBoxesShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDrawBoundingBoxesShape(Nd4jBackend backend) { INDArray images = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f,0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f, 0.3087f,0.1548f,0.4695f,0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f, @@ -903,7 +1020,7 @@ public class CustomOpsTests extends BaseNd4jTest { @Test @Disabled("Failing with results that are close") - public void testFakeQuantAgainstTF_1() { + public void testFakeQuantAgainstTF_1(Nd4jBackend backend) { INDArray x = Nd4j.createFromArray(new double[]{ 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5); @@ -919,8 +1036,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, output[0]); } + @Test - public void testWhereFail() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testWhereFail(Nd4jBackend backend) { INDArray in = Nd4j.createFromArray(new float[]{0f, 1.0000f, 1.0000f, 1.0000f, 1.0000f}); INDArray out = Nd4j.createUninitialized(4,1); INDArray expected = Nd4j.createFromArray(4,1); @@ -929,8 +1049,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(new long[]{4,1} , out.shape()); } + @Test - public void testResizeBilinear1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testResizeBilinear1(Nd4jBackend backend) { INDArray x = Nd4j.rand(1, 10,10,4); INDArray z = Nd4j.createUninitialized(x.shape()); boolean align = false; @@ -938,8 +1061,11 @@ public class CustomOpsTests extends BaseNd4jTest { Nd4j.exec(op); } + @Test - public void testResizeArea1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testResizeArea1(Nd4jBackend backend) { INDArray x = Nd4j.rand(DataType.FLOAT, 1, 2,3,4); INDArray z = Nd4j.createUninitialized(DataType.FLOAT, 1, 10, 10, 4); @@ -947,8 +1073,11 @@ public class CustomOpsTests extends BaseNd4jTest { Nd4j.exec(op); } + @Test - public void testResizeArea2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testResizeArea2(Nd4jBackend backend) { INDArray image = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 9 ).reshape(1,3,3,1); INDArray output = Nd4j.createUninitialized(DataType.FLOAT, 1, 6, 6, 1); @@ -967,8 +1096,11 @@ public class CustomOpsTests extends BaseNd4jTest { + @Test - public void testDivideNoNan() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDivideNoNan(Nd4jBackend backend) { INDArray in1 = Nd4j.rand(DataType.DOUBLE, 2,3,4); INDArray in2 = Nd4j.rand(DataType.DOUBLE, 2,3,4); INDArray out = Nd4j.createUninitialized(DataType.DOUBLE, 2,3,4); @@ -979,7 +1111,9 @@ public class CustomOpsTests extends BaseNd4jTest { @Test @Disabled - public void testDrawBoundingBoxes() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDrawBoundingBoxes(Nd4jBackend backend) { INDArray images = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 2*4*5*3).reshape(2,4,5,3); INDArray boxes = Nd4j.createFromArray(new float[]{ 0.0f , 0.0f , 1.0f , 1.0f, 0.1f, 0.2f, 0.9f, 0.8f, @@ -1007,8 +1141,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, output); } + @Test - public void FakeQuantWithMinMaxVarsPerChannel() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void FakeQuantWithMinMaxVarsPerChannel(Nd4jBackend backend) { INDArray x = Nd4j.createFromArray(new float[]{-63.80f, -63.75f, -63.4f, -63.5f, 0.0f, 0.1f}). reshape(1,2,3,1); @@ -1024,8 +1161,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, output[0]); } + @Test - public void testKnnMinDistance() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testKnnMinDistance(Nd4jBackend backend) { INDArray point = Nd4j.rand(DataType.FLOAT, 1, 20); INDArray lowest = Nd4j.rand(DataType.FLOAT, 1, 20); INDArray highest = Nd4j.rand(DataType.FLOAT, 1, 20); @@ -1035,8 +1175,11 @@ public class CustomOpsTests extends BaseNd4jTest { } + @Test - public void testLayersDropoutFail() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLayersDropoutFail(Nd4jBackend backend) { INDArray input = Nd4j.rand(4, 5); INDArray output = Nd4j.createUninitialized(4, 5); DropOut op = new DropOut(input, output, 0.1); @@ -1044,7 +1187,10 @@ public class CustomOpsTests extends BaseNd4jTest { // System.out.println(output); } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRange(){ DynamicCustomOp op = DynamicCustomOp.builder("range") .addFloatingPointArguments(-1.0, 1.0, 0.01) @@ -1057,7 +1203,10 @@ public class CustomOpsTests extends BaseNd4jTest { Nd4j.exec(op); } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBitCastShape_1(){ val out = Nd4j.createUninitialized(1,10); BitCast op = new BitCast(Nd4j.zeros(DataType.FLOAT,1,10), DataType.INT.toInt(), out); @@ -1066,7 +1215,10 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(new long[]{1,10}, lsd.get(0).getShape()); } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBitCastShape_2(){ val out = Nd4j.createUninitialized(1,10); BitCast op = new BitCast(Nd4j.zeros(DataType.DOUBLE,1,10), DataType.INT.toInt(), out); @@ -1075,8 +1227,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape()); } + @Test - public void testFusedBatchNorm() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFusedBatchNorm(Nd4jBackend backend) { INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*2*3*4).reshape(2,2,3,4); INDArray scale = Nd4j.create(DataType.DOUBLE, 4); scale.assign(0.5); @@ -1106,8 +1261,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(expectedBatchVar.shape(), batchVar.shape()); } + @Test - public void testFusedBatchNorm1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFusedBatchNorm1(Nd4jBackend backend) { INDArray x = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, @@ -1134,8 +1292,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(expectedY.shape(), y.shape()); } + @Test - public void testFusedBatchNormHalf() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFusedBatchNormHalf(Nd4jBackend backend) { INDArray x = Nd4j.create(DataType.HALF, 1,2,3,4); //INDArray scale = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f}); //INDArray offset = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f}); @@ -1151,8 +1312,11 @@ public class CustomOpsTests extends BaseNd4jTest { Nd4j.exec(op); } + @Test - public void testMatrixBandPart() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatrixBandPart(Nd4jBackend backend) { INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3); val op = new MatrixBandPart(x,1,1); INDArray expected = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3); @@ -1166,8 +1330,11 @@ public class CustomOpsTests extends BaseNd4jTest { } @Disabled("AS failed 2019/12/04") + @Test - public void testPolygamma() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPolygamma(Nd4jBackend backend) { INDArray n = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3,3); INDArray x = Nd4j.create(DataType.FLOAT, 3,3); x.assign(0.5); @@ -1179,8 +1346,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, output); } + @Test - public void testLgamma() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLgamma(Nd4jBackend backend) { INDArray x = Nd4j.createFromArray(new double[]{0.1, 0.5, 0.7, 1.5, 1.7, 2.0, 2.5, 2.7, 3.}).reshape(3,3); INDArray expected = Nd4j.createFromArray(new double[]{ 2.2527127 , 0.5723649 , 0.26086727, @@ -1191,16 +1361,22 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testRandomCrop() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRandomCrop(Nd4jBackend backend) { INDArray x = Nd4j.createFromArray(new double[]{1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }).reshape(2,2,4); INDArray shape = Nd4j.createFromArray(new int[] {1,2,3}); val op = new RandomCrop(x, shape); INDArray[] res = Nd4j.exec(op); } + @Test - public void testRoll() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRoll(Nd4jBackend backend) { INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}). reshape(2,2,4,2); @@ -1214,8 +1390,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, res[0]); } + @Test - public void testToggleBits() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToggleBits(Nd4jBackend backend) { INDArray input = Nd4j.createFromArray(new int[]{2,2}); INDArray expected = Nd4j.createFromArray(new int[]{-3,-3}); ToggleBits op = new ToggleBits(input); @@ -1224,8 +1403,11 @@ public class CustomOpsTests extends BaseNd4jTest { } @Disabled("AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8449") + @Test - public void testNonMaxSuppression() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNonMaxSuppression(Nd4jBackend backend) { INDArray boxes = Nd4j.createFromArray(new float[] {0.8115f, 0.4121f, 0.0771f, 0.4863f, 0.7412f, 0.7607f, 0.1543f, 0.5479f, 0.8223f, 0.2246f, 0.0049f, 0.6465f}).reshape(3,4); @@ -1235,8 +1417,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(new long[]{1}, res[0].shape()); } + @Test - public void testMatrixBand() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatrixBand(Nd4jBackend backend) { INDArray input = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f, 0.7271f,0.1804f,0.5056f,0.8925f, 0.5461f,0.9234f,0.0856f,0.7938f}).reshape(3,4); @@ -1246,8 +1431,11 @@ public class CustomOpsTests extends BaseNd4jTest { } @Disabled("Failed AS 11.26.2019 - https://github.com/eclipse/deeplearning4j/issues/8450") + @Test - public void testBetaInc1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBetaInc1(Nd4jBackend backend) { INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f}); INDArray b = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f}); INDArray c = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f}); @@ -1258,8 +1446,11 @@ public class CustomOpsTests extends BaseNd4jTest { } @Disabled("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8452") + @Test - public void testPolygamma1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPolygamma1(Nd4jBackend backend) { INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f}).reshape(3,4); @@ -1272,8 +1463,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testRoll1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRoll1(Nd4jBackend backend) { INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f}); Roll op = new Roll(a,Nd4j.scalar(2),Nd4j.scalar(0)); INDArray[] ret = Nd4j.exec(op); @@ -1285,7 +1479,10 @@ public class CustomOpsTests extends BaseNd4jTest { System.out.println(outputs[0]); } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testAdjustHueShape(){ INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f, @@ -1329,7 +1526,10 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(new long[]{8, 8, 3}, lsd.get(0).getShape()); } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBitCastShape_3(){ val x = Nd4j.createFromArray(new int[]{1, 2, 3, 4, 5, 6, 7, 8}).reshape(1, 4, 2); val e = Nd4j.createFromArray(new long[]{8589934593L, 17179869187L, 25769803781L, 34359738375L}).reshape(1, 4); @@ -1339,8 +1539,11 @@ public class CustomOpsTests extends BaseNd4jTest { } + @Test - public void testMatch_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatch_1(Nd4jBackend backend) { INDArray x = Nd4j.ones(DataType.FLOAT, 3,3); INDArray y = Nd4j.linspace(DataType.FLOAT, -5, 9, 1).reshape(3, 3); val c = Conditions.equals(0.0); @@ -1355,8 +1558,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, z); } + @Test - public void testCreateOp_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCreateOp_1(Nd4jBackend backend) { val shape = Nd4j.createFromArray(new int[] {3, 4, 5}); val exp = Nd4j.create(DataType.INT, 3, 4, 5); @@ -1368,7 +1574,9 @@ public class CustomOpsTests extends BaseNd4jTest { // Exact copy of libnd4j test @Test @Disabled - public void testRgbToHsv() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRgbToHsv(Nd4jBackend backend) { INDArray expected = Nd4j.createFromArray(new float[]{ 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, @@ -1403,8 +1611,11 @@ public class CustomOpsTests extends BaseNd4jTest { } // Exact copy of libnd4j test + @Test - public void testHsvToRgb() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHsvToRgb(Nd4jBackend backend) { INDArray input = Nd4j.createFromArray(new float[]{0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, 0.332347751f, 0.111181192f}).reshape(4,3); @@ -1418,8 +1629,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(ret[0], expected); } + @Test - public void testHsvToRgb_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHsvToRgb_1(Nd4jBackend backend) { /* Emulation of simple TF test: image = tf.random_uniform(shape = [1,1,3]) tf.image.hsv_to_rgb(image)*/ @@ -1432,8 +1646,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testRgbToHsv_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRgbToHsv_1(Nd4jBackend backend) { /* Emulation of simple TF test: image = tf.random_uniform(shape = [1,2,3]) tf.image.rgb_to_hsv(image)*/ @@ -1446,8 +1663,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testLu() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLu(Nd4jBackend backend) { INDArray input = Nd4j.createFromArray(new float[]{1.f, 2.f, 3.f, 0.f, 2.f, 3.f, 0.f, 0.f, 7.f}) .reshape(3,3); Lu op = new Lu(input); @@ -1457,8 +1677,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testRgbToYiq() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRgbToYiq(Nd4jBackend backend) { INDArray image = Nd4j.createFromArray(new float[]{ 0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f, 0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f , @@ -1494,8 +1717,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testYiqToRgb() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testYiqToRgb(Nd4jBackend backend) { INDArray image = Nd4j.createFromArray(new float[]{ 0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f, 0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f, @@ -1531,8 +1757,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testRgbToGrayscale() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRgbToGrayscale(Nd4jBackend backend) { INDArray image = Nd4j.createFromArray(new float[]{ 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, @@ -1561,8 +1790,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testRgbToYuv() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRgbToYuv(Nd4jBackend backend) { INDArray image = Nd4j.createFromArray(new float[]{ 10f,50f,200f }); @@ -1576,8 +1808,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testYuvToRgb() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testYuvToRgb(Nd4jBackend backend) { INDArray image = Nd4j.createFromArray(new float[]{ 55.14f , 71.2872001f, -39.6005542f }); @@ -1590,16 +1825,22 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testRgbToYiqEmpty() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRgbToYiqEmpty(Nd4jBackend backend) { INDArray image = Nd4j.create(0,4,3); RgbToYiq op = new RgbToYiq(image); INDArray[] ret = Nd4j.exec(op); assertArrayEquals(image.shape(), ret[0].shape()); } + @Test - public void testTriangularSolve() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTriangularSolve(Nd4jBackend backend) { INDArray a = Nd4j.createFromArray(new float[]{ 3.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, @@ -1621,8 +1862,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testOnesLike_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOnesLike_1(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 3, 4, 5); val e = Nd4j.ones(DataType.INT32, 3, 4, 5); @@ -1630,16 +1874,22 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(e, z); } + @Test - public void testLinSpaceEdge_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLinSpaceEdge_1(Nd4jBackend backend) { val x = Nd4j.linspace(1,10,1, DataType.FLOAT); val e = Nd4j.scalar(1.0f); assertEquals(e, x); } + @Test - public void testLinearSolve() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLinearSolve(Nd4jBackend backend) { INDArray a = Nd4j.createFromArray(new float[]{ 2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f }).reshape(3, 3); @@ -1658,8 +1908,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testLinearSolveAdjust() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLinearSolveAdjust(Nd4jBackend backend) { INDArray a = Nd4j.createFromArray(new float[]{ 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, @@ -1684,8 +1937,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testLstsq() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLstsq(Nd4jBackend backend) { INDArray a = Nd4j.createFromArray(new float[]{ 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, @@ -1706,8 +1962,11 @@ public class CustomOpsTests extends BaseNd4jTest { } } + @Test - public void testSequenceMask() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSequenceMask(Nd4jBackend backend) { INDArray arr = Nd4j.createFromArray(new int[]{1, 3, 2}); // Test with static max len int maxlen = 2; @@ -1721,8 +1980,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + @Test - public void testCholesky() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCholesky(Nd4jBackend backend) { INDArray x = Nd4j.createFromArray(new double[] {4,12,-16, 12 ,37,-43, -16, -43, 98}).reshape(3,3); INDArray exp = Nd4j.createFromArray(new double[] {2., 0., 0., 6., 1., 0., -8., 5., 3.}).reshape(3,3); @@ -1730,8 +1992,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(res[0], exp); } + @Test - public void testQr() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testQr(Nd4jBackend backend) { INDArray in = Nd4j.createFromArray(new double[]{ 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3. }).reshape(5,3); @@ -1746,7 +2011,10 @@ public class CustomOpsTests extends BaseNd4jTest { } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLinspaceSignature_1() throws Exception { val array1 = Nd4j.exec(new Linspace(DataType.FLOAT, Nd4j.scalar(1.0f), Nd4j.scalar(10.f), Nd4j.scalar(10L)))[0]; val array2 = Nd4j.exec(new Linspace(DataType.FLOAT, 1.0f, 10.f, 10L))[0]; @@ -1755,8 +2023,11 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(array1, array2); } + @Test - public void testLogdet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLogdet(Nd4jBackend backend) { INDArray x = Nd4j.createFromArray(new double[]{ 4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8 }).reshape(2,3,3); @@ -1767,7 +2038,10 @@ public class CustomOpsTests extends BaseNd4jTest { } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBatchNormBpNHWC(){ //Nd4j.getEnvironment().allowHelpers(false); //Passes if helpers/MKLDNN is disabled @@ -1811,7 +2085,10 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(out1v, out2v); } + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSpaceToDepthBadStrides(){ INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3, 6, 6); INDArray inBadStrides = in.permute(1,0,2,3).dup().permute(1,0,2,3); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java index eef78e1e9..ee03f8154 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/ExpandableOpsTests.java @@ -23,7 +23,9 @@ package org.nd4j.linalg.custom; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ops.compat.CompatStringSplit; import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.factory.Nd4j; @@ -33,11 +35,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @Slf4j -public class ExpandableOpsTests extends BaseNd4jTest { +public class ExpandableOpsTests extends BaseNd4jTestWithBackends { - public ExpandableOpsTests(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -45,7 +44,9 @@ public class ExpandableOpsTests extends BaseNd4jTest { } @Test - public void testCompatStringSplit_1() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCompatStringSplit_1(Nd4jBackend backend) throws Exception { val array = Nd4j.create("first string", "second"); val delimiter = Nd4j.create(" "); @@ -61,7 +62,9 @@ public class ExpandableOpsTests extends BaseNd4jTest { } @Test - public void test() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test(Nd4jBackend backend) { val arr = Nd4j.createFromArray(0, 1, 2, 3, 4, 5, 6, 7, 8).reshape(3, 3); Nd4j.exec(new PrintVariable(arr)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java index 704519e79..072419e73 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java @@ -24,7 +24,9 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4jBackend; @@ -32,18 +34,18 @@ import java.io.File; import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.Map; import static org.junit.jupiter.api.Assertions.assertTrue; -public class BalanceMinibatchesTest extends BaseNd4jTest { - public BalanceMinibatchesTest(Nd4jBackend backend) { - super(backend); - } +public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends { @Test - public void testBalance(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBalance(@TempDir Path testDir,Nd4jBackend backend) throws Exception { DataSetIterator iterator = new IrisDataSetIterator(10, 150); File minibatches = new File(testDir.toFile(),"mini-batch-dir"); @@ -60,7 +62,9 @@ public class BalanceMinibatchesTest extends BaseNd4jTest { } @Test - public void testMiniBatchBalanced(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMiniBatchBalanced(@TempDir Path testDir,Nd4jBackend backend) throws Exception { int miniBatchSize = 100; DataSetIterator iterator = new IrisDataSetIterator(miniBatchSize, 150); @@ -87,7 +91,7 @@ public class BalanceMinibatchesTest extends BaseNd4jTest { } - ArrayList fullBatches = new ArrayList(totalCounts.length); + List fullBatches = new ArrayList(totalCounts.length); for (int i = 0; i < totalCounts.length; i++) { fullBatches.add(iterator.totalOutcomes() * (int) totalCounts[i] / miniBatchSize); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/CachingDataSetIteratorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/CachingDataSetIteratorTest.java index a89fc43c7..0e5928f4d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/CachingDataSetIteratorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/CachingDataSetIteratorTest.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.dataset; import org.apache.commons.io.FileUtils; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; import org.nd4j.linalg.dataset.api.iterator.CachingDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -42,12 +43,9 @@ import java.nio.file.Path; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class CachingDataSetIteratorTest extends BaseNd4jTest { - public CachingDataSetIteratorTest(Nd4jBackend backend) { - super(backend); - } +public class CachingDataSetIteratorTest extends BaseNd4jTestWithBackends { + @Override public char ordering() { @@ -55,13 +53,17 @@ public class CachingDataSetIteratorTest extends BaseNd4jTest { } @Test - public void testInMemory() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInMemory(Nd4jBackend backend) { DataSetCache cache = new InMemoryDataSetCache(); runDataSetTest(cache); } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testInFile() throws IOException { Path cacheDir = Files.createTempDirectory("nd4j-data-set-cache-test"); DataSetCache cache = new InFileDataSetCache(cacheDir); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java index ee927b330..b79852d52 100755 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java @@ -26,9 +26,10 @@ import lombok.val; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; @@ -48,18 +49,13 @@ import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.indexing.NDArrayIndex.*; @Slf4j -@RunWith(Parameterized.class) -public class DataSetTest extends BaseNd4jTest { - - - - public DataSetTest(Nd4jBackend backend) { - super(backend); - } - - @Test - public void testViewIterator() { +public class DataSetTest extends BaseNd4jTestWithBackends { + + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testViewIterator(Nd4jBackend backend) { DataSetIterator iter = new ViewIterator(new IrisDataSetIterator(150, 150).next(), 10); assertTrue(iter.hasNext()); int count = 0; @@ -76,7 +72,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testViewIterator2(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testViewIterator2(Nd4jBackend backend){ INDArray f = Nd4j.linspace(1,100,100, DataType.DOUBLE).reshape('c', 10, 10); DataSet ds = new DataSet(f, f); @@ -92,7 +90,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testViewIterator3(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testViewIterator3(Nd4jBackend backend){ INDArray f = Nd4j.linspace(1,100,100, DataType.DOUBLE).reshape('c', 10, 10); DataSet ds = new DataSet(f, f); @@ -109,8 +109,10 @@ public class DataSetTest extends BaseNd4jTest { - @Test - public void testSplitTestAndTrain() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSplitTestAndTrain (Nd4jBackend backend) { INDArray labels = FeatureUtil.toOutcomeMatrix(new int[] {0, 0, 0, 0, 0, 0, 0, 0}, 1); DataSet data = new DataSet(Nd4j.rand(8, 1), labels); @@ -130,7 +132,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testSplitTestAndTrainRng() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSplitTestAndTrainRng(Nd4jBackend backend) { Random rngHere; @@ -152,7 +156,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testLabelCounts() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLabelCounts(Nd4jBackend backend) { DataSet x0 = new IrisDataSetIterator(150, 150).next(); assertEquals(0, x0.get(0).outcome(),getFailureMessage()); assertEquals( 0, x0.get(1).outcome(),getFailureMessage()); @@ -165,7 +171,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testTimeSeriesMerge() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTimeSeriesMerge(Nd4jBackend backend) { //Basic test for time series, all of the same length + no masking arrays int numExamples = 10; int inSize = 13; @@ -202,7 +210,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testTimeSeriesMergeDifferentLength() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTimeSeriesMergeDifferentLength(Nd4jBackend backend) { //Test merging of time series with different lengths -> no masking arrays on the input DataSets int numExamples = 10; @@ -295,7 +305,9 @@ public class DataSetTest extends BaseNd4jTest { @Test - public void testTimeSeriesMergeWithMasking() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTimeSeriesMergeWithMasking(Nd4jBackend backend) { //Test merging of time series with (a) different lengths, and (b) mask arrays in the input DataSets int numExamples = 10; @@ -404,7 +416,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testCnnMerge() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCnnMerge (Nd4jBackend backend) { //Test merging of CNN data sets int nOut = 3; int width = 5; @@ -483,7 +497,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testCnnMergeFeatureMasks() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCnnMergeFeatureMasks(Nd4jBackend backend) { //Tests merging of different CNN masks: [mb,1,h,1], [mb,1,1,w], [mb,1,h,w] for( int t=0; t<3; t++) { @@ -600,7 +616,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testMixedRnn2dMerging() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMixedRnn2dMerging (Nd4jBackend backend) { //RNN input with 2d label output //Basic test for time series, all of the same length + no masking arrays int numExamples = 10; @@ -638,7 +656,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testMergingWithPerOutputMasking() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergingWithPerOutputMasking (Nd4jBackend backend) { //Test 2d mask merging, 2d data //features @@ -711,7 +731,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testShuffle4d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShuffle4d(Nd4jBackend backend) { int nSamples = 10; int nChannels = 3; int imgRows = 4; @@ -742,7 +764,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testShuffleNd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShuffleNd(Nd4jBackend backend) { int numDims = 7; int nLabels = 3; Random r = new Random(); @@ -792,7 +816,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testShuffleMeta() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShuffleMeta(Nd4jBackend backend) { int nExamples = 20; int nColumns = 4; @@ -826,7 +852,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testLabelNames() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLabelNames(Nd4jBackend backend) { List names = Arrays.asList("label1", "label2", "label3", "label0"); INDArray features = Nd4j.ones(10); INDArray labels = Nd4j.linspace(0, 3, 4, DataType.DOUBLE); @@ -838,7 +866,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testToString() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToString(Nd4jBackend backend) { org.nd4j.linalg.dataset.api.DataSet ds = new DataSet(); //this should not throw a null pointer // System.out.println(ds); @@ -865,7 +895,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testGetRangeMask() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRangeMask(Nd4jBackend backend) { org.nd4j.linalg.dataset.api.DataSet ds = new DataSet(); //Checking printing of masks int numExamples = 10; @@ -894,7 +926,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testAsList() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAsList(Nd4jBackend backend) { org.nd4j.linalg.dataset.api.DataSet ds; //Comparing merge with asList int numExamples = 10; @@ -930,7 +964,9 @@ public class DataSetTest extends BaseNd4jTest { @Test - public void testDataSetSaveLoad() throws IOException { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDataSetSaveLoad(Nd4jBackend backend) throws IOException { boolean[] b = new boolean[] {true, false}; @@ -979,7 +1015,9 @@ public class DataSetTest extends BaseNd4jTest { @Test - public void testDataSetSaveLoadSingle() throws IOException { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDataSetSaveLoadSingle(Nd4jBackend backend) throws IOException { INDArray f = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape('c', 4, 3, 2); INDArray l = Nd4j.linspace(24, 48, 24, DataType.DOUBLE).reshape('c', 4, 3, 2); @@ -1017,7 +1055,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testMdsShuffle(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMdsShuffle(Nd4jBackend backend) { MultiDataSet orig = new MultiDataSet(Nd4j.linspace(1,100,100, DataType.DOUBLE).reshape('c',10,10), Nd4j.linspace(100,200,100, DataType.DOUBLE).reshape('c',10,10)); @@ -1054,7 +1094,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testSample4d(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSample4d(Nd4jBackend backend) { Nd4j.getRandom().setSeed(12345); int next1 = Nd4j.getRandom().nextInt(4); int next2 = Nd4j.getRandom().nextInt(4); @@ -1062,7 +1104,7 @@ public class DataSetTest extends BaseNd4jTest { assertNotEquals(next1, next2); INDArray arr = Nd4j.create(DataType.DOUBLE, 4,1,5,5); - for( int i=0; i<4; i++ ){ + for( int i = 0; i < 4; i++) { arr.get(point(i), all(), all(), all()).assign(i); } @@ -1079,7 +1121,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testDataSetMetaDataSerialization(@TempDir Path testDir) throws IOException { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend backend) throws IOException { for(boolean withMeta : new boolean[]{false, true}) { // create simple data set with meta data object @@ -1109,7 +1153,9 @@ public class DataSetTest extends BaseNd4jTest { } @Test - public void testMultiDataSetMetaDataSerialization(@TempDir Path testDir) throws IOException { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultiDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend nd4jBackend) throws IOException { for(boolean withMeta : new boolean[]{false, true}) { // create simple data set with meta data object diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java index 8c43c5f30..cdabf5cdb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator; @@ -37,14 +38,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -@RunWith(Parameterized.class) -public class ImagePreProcessortTest extends BaseNd4jTest { - public ImagePreProcessortTest(Nd4jBackend backend) { - super(backend); - } + +public class ImagePreProcessortTest extends BaseNd4jTestWithBackends { @Test - public void simpleImageTest() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void simpleImageTest(Nd4jBackend backend) { INDArray rChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(128); INDArray gChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(64); INDArray bChannels = Nd4j.zeros(DataType.FLOAT, 10, 10).addi(255); @@ -104,7 +104,9 @@ public class ImagePreProcessortTest extends BaseNd4jTest { } @Test - public void simpleImageTestMulti() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void simpleImageTestMulti(Nd4jBackend backend) { INDArray rChannels = Nd4j.zeros(10, 10).addi(128); INDArray gChannels = Nd4j.zeros(10, 10).addi(64); INDArray bChannels = Nd4j.zeros(10, 10).addi(255); @@ -160,7 +162,9 @@ public class ImagePreProcessortTest extends BaseNd4jTest { @Test - public void testSegmentation(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSegmentation(Nd4jBackend backend){ INDArray f = Nd4j.math().floor(Nd4j.rand(DataType.FLOAT, 3, 3, 16, 16).muli(255)); INDArray l = Nd4j.math().floor(Nd4j.rand(DataType.FLOAT, 3, 10, 8, 8).muli(255)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java index 95ea38171..fc21524d9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.KFoldIterator; @@ -34,55 +35,56 @@ import java.util.HashSet; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class KFoldIteratorTest extends BaseNd4jTest { - public KFoldIteratorTest(Nd4jBackend backend) { - super(backend); +public class KFoldIteratorTest extends BaseNd4jTestWithBackends { + + + + /** + * Try every possible k number of folds from 2 to the number of examples, + * and check that every example will be exactly once in the test set, + * and the sum of the number of test examples in all folds equals to the number of examples. + */ + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void checkTestFoldContent(Nd4jBackend backend) { + + final int numExamples = 42; + final int numFeatures = 3; + INDArray features = Nd4j.rand(new int[] {numExamples, numFeatures}); + INDArray labels = Nd4j.linspace(1, numExamples, numExamples, DataType.DOUBLE).reshape(-1, 1); + + DataSet dataSet = new DataSet(features, labels); + + for (int k = 2; k <= numExamples; k++) { + KFoldIterator kFoldIterator = new KFoldIterator(k, dataSet); + HashSet testLabels = new HashSet(); + for (int i = 0; i < k; i++) { + kFoldIterator.next(); + DataSet testFold = kFoldIterator.testFold(); + for (DataSet testExample : testFold) { + /** + * Check that the current example has not been in the test set before + */ + INDArray testedLabel = testExample.getLabels(); + assertTrue(testLabels.add(testedLabel.getDouble(0))); + } + } + /** + * Check that the sum of the number of test examples in all folds equals to the number of examples + */ + assertEquals(numExamples, testLabels.size()); + } } - - /** - * Try every possible k number of folds from 2 to the number of examples, - * and check that every example will be exactly once in the test set, - * and the sum of the number of test examples in all folds equals to the number of examples. - */ - @Test - public void checkTestFoldContent() { - - final int numExamples = 42; - final int numFeatures = 3; - INDArray features = Nd4j.rand(new int[] {numExamples, numFeatures}); - INDArray labels = Nd4j.linspace(1, numExamples, numExamples, DataType.DOUBLE).reshape(-1, 1); - - DataSet dataSet = new DataSet(features, labels); - - for (int k = 2; k <= numExamples; k++) { - KFoldIterator kFoldIterator = new KFoldIterator(k, dataSet); - HashSet testLabels = new HashSet(); - for (int i = 0; i < k; i++) { - kFoldIterator.next(); - DataSet testFold = kFoldIterator.testFold(); - for (DataSet testExample : testFold) { - /** - * Check that the current example has not been in the test set before - */ - INDArray testedLabel = testExample.getLabels(); - assertTrue(testLabels.add(testedLabel.getDouble(0))); - } - } - /** - * Check that the sum of the number of test examples in all folds equals to the number of examples - */ - assertEquals(numExamples, testLabels.size()); - } - } - @Test - public void checkFolds() { - // Expected batch sizes: 3+3+3+2 = 11 total examples - int[] batchSizesExp = new int[] {3, 3, 3, 2}; + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void checkFolds(Nd4jBackend backend) { + // Expected batch sizes: 3+3+3+2 = 11 total examples + int[] batchSizesExp = new int[] {3, 3, 3, 2}; KBatchRandomDataSet randomDS = new KBatchRandomDataSet(new int[] {2, 3}, batchSizesExp); DataSet allData = randomDS.getAllBatches(); KFoldIterator kiter = new KFoldIterator(4, allData); @@ -98,16 +100,16 @@ public class KFoldIteratorTest extends BaseNd4jTest { assertEquals(randomDS.getBatchK(i, true), test.getFeatures()); assertEquals(randomDS.getBatchK(i, false), test.getLabels()); - + assertEquals(batchSizesExp[i], test.getLabels().length()); i++; } assertEquals(i, 4); } - + @Test() - public void checkCornerCaseException() { + public void checkCornerCaseException(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { DataSet allData = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1), Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1)); @@ -119,9 +121,11 @@ public class KFoldIteratorTest extends BaseNd4jTest { } @Test - public void checkCornerCase() { - // Expected batch sizes: 2+1 = 3 total examples - int[] batchSizesExp = new int[] {2, 1}; + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void checkCornerCase(Nd4jBackend backend) { + // Expected batch sizes: 2+1 = 3 total examples + int[] batchSizesExp = new int[] {2, 1}; KBatchRandomDataSet randomDS = new KBatchRandomDataSet(new int[] {2, 3}, batchSizesExp); DataSet allData = randomDS.getAllBatches(); KFoldIterator kiter = new KFoldIterator(2, allData); @@ -135,14 +139,14 @@ public class KFoldIteratorTest extends BaseNd4jTest { assertEquals(randomDS.getBatchK(i, true), test.getFeatures()); assertEquals(randomDS.getBatchK(i, false), test.getLabels()); - + assertEquals(batchSizesExp[i], test.getLabels().length()); i++; } assertEquals(i, 2); } - + /** * Dataset built from given sized batches of random data * @author susaneraly created RandomDataSet @@ -225,12 +229,14 @@ public class KFoldIteratorTest extends BaseNd4jTest { return batches; } } - - + + @Test - public void test5974(){ - DataSet ds = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1), - Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1)); + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test5974(Nd4jBackend backend){ + DataSet ds = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1), + Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1)); KFoldIterator iter = new KFoldIterator(10, ds); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MinMaxStatsTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MinMaxStatsTest.java index 857e95c6d..adbf82aae 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MinMaxStatsTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MinMaxStatsTest.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.preprocessor.stats.MinMaxStats; import org.nd4j.linalg.factory.Nd4j; @@ -34,21 +35,20 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Ede Meijer */ -@RunWith(Parameterized.class) -public class MinMaxStatsTest extends BaseNd4jTest { - public MinMaxStatsTest(Nd4jBackend backend) { - super(backend); - } + +public class MinMaxStatsTest extends BaseNd4jTestWithBackends { @Test - public void testEnforcingNonZeroRange() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEnforcingNonZeroRange(Nd4jBackend backend) { INDArray lower = Nd4j.create(new double[] {2, 3, 4, 5}); MinMaxStats stats = new MinMaxStats(lower.dup(), - Nd4j.create(new double[] {8, 3, 3.9, 5 + Nd4j.EPS_THRESHOLD * 0.5})); + Nd4j.create(new double[] {8, 3, 3.9, 5 + Nd4j.EPS_THRESHOLD * 0.5})); INDArray expectedUpper = Nd4j.create( - new double[] {8, 3 + Nd4j.EPS_THRESHOLD, 4 + Nd4j.EPS_THRESHOLD, 5 + Nd4j.EPS_THRESHOLD}); + new double[] {8, 3 + Nd4j.EPS_THRESHOLD, 4 + Nd4j.EPS_THRESHOLD, 5 + Nd4j.EPS_THRESHOLD}); assertEquals(lower, stats.getLower()); assertEquals(expectedUpper, stats.getUpper()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIteratorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIteratorTest.java index 3391af730..b39b7c90d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIteratorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIteratorTest.java @@ -24,27 +24,24 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.factory.Nd4jBackend; import java.nio.file.Path; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) -public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTest { - - public MiniBatchFileDataSetIteratorTest(Nd4jBackend backend) { - super(backend); - } +public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTestWithBackends { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMiniBatches(@TempDir Path testDir) throws Exception { DataSet load = new IrisDataSetIterator(150, 150).next(); final MiniBatchFileDataSetIterator iter = new MiniBatchFileDataSetIterator(load, 10, false, testDir.toFile()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiDataSetTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiDataSetTest.java index ced615d55..64391e818 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiDataSetTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiDataSetTest.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.dataset; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; @@ -44,14 +45,13 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; @Slf4j -@RunWith(Parameterized.class) -public class MultiDataSetTest extends BaseNd4jTest { - public MultiDataSetTest(Nd4jBackend backend) { - super(backend); - } + +public class MultiDataSetTest extends BaseNd4jTestWithBackends { @Test - public void testMerging2d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMerging2d(Nd4jBackend backend) { //Simple test: single input/output arrays; 5 MultiDataSets to merge int nCols = 3; int nRows = 5; @@ -79,7 +79,9 @@ public class MultiDataSetTest extends BaseNd4jTest { } @Test - public void testMerging2dMultipleInOut() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMerging2dMultipleInOut(Nd4jBackend backend) { //Test merging: Multiple input/output arrays; 5 MultiDataSets to merge int nRows = 5; @@ -123,7 +125,9 @@ public class MultiDataSetTest extends BaseNd4jTest { } @Test - public void testMerging2dMultipleInOut2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMerging2dMultipleInOut2(Nd4jBackend backend) { //Test merging: Multiple input/output arrays; 5 MultiDataSets to merge int nRows = 10; @@ -177,7 +181,9 @@ public class MultiDataSetTest extends BaseNd4jTest { } @Test - public void testMerging2dMultipleInOut3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMerging2dMultipleInOut3(Nd4jBackend backend) { //Test merging: fewer rows than output arrays... int nRows = 2; @@ -219,7 +225,9 @@ public class MultiDataSetTest extends BaseNd4jTest { } @Test - public void testMerging4dMultipleInOut() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMerging4dMultipleInOut(Nd4jBackend backend) { int nRows = 5; int depthIn0 = 3; int widthIn0 = 4; @@ -244,18 +252,18 @@ public class MultiDataSetTest extends BaseNd4jTest { if (i == 0) { //For first MultiDataSet: have 2 rows, not just 1 INDArray in0 = expIn0.get(NDArrayIndex.interval(0, 1, true), NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.all()).dup(); + NDArrayIndex.all()).dup(); INDArray in1 = expIn1.get(NDArrayIndex.interval(0, 1, true), NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.all()).dup(); + NDArrayIndex.all()).dup(); INDArray out0 = expOut0.get(NDArrayIndex.interval(0, 1, true), NDArrayIndex.all()).dup(); INDArray out1 = expOut1.get(NDArrayIndex.interval(0, 1, true), NDArrayIndex.all()).dup(); list.add(new MultiDataSet(new INDArray[] {in0, in1}, new INDArray[] {out0, out1})); i++; } else { INDArray in0 = expIn0.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.all()).dup(); + NDArrayIndex.all()).dup(); INDArray in1 = expIn1.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.all()).dup(); + NDArrayIndex.all()).dup(); INDArray out0 = expOut0.getRow(i, true).dup(); INDArray out1 = expOut1.getRow(i, true).dup(); list.add(new MultiDataSet(new INDArray[] {in0, in1}, new INDArray[] {out0, out1})); @@ -273,7 +281,9 @@ public class MultiDataSetTest extends BaseNd4jTest { } @Test - public void testMergingTimeSeriesEqualLength() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergingTimeSeriesEqualLength(Nd4jBackend backend) { int tsLength = 8; int nRows = 5; int nColsIn0 = 3; @@ -295,24 +305,24 @@ public class MultiDataSetTest extends BaseNd4jTest { if (i == 0) { //For first MultiDataSet: have 2 rows, not just 1 INDArray in0 = expIn0.get(NDArrayIndex.interval(0, 1, true), NDArrayIndex.all(), NDArrayIndex.all()) - .dup(); + .dup(); INDArray in1 = expIn1.get(NDArrayIndex.interval(0, 1, true), NDArrayIndex.all(), NDArrayIndex.all()) - .dup(); + .dup(); INDArray out0 = expOut0.get(NDArrayIndex.interval(0, 1, true), NDArrayIndex.all(), NDArrayIndex.all()) - .dup(); + .dup(); INDArray out1 = expOut1.get(NDArrayIndex.interval(0, 1, true), NDArrayIndex.all(), NDArrayIndex.all()) - .dup(); + .dup(); list.add(new MultiDataSet(new INDArray[] {in0, in1}, new INDArray[] {out0, out1})); i++; } else { INDArray in0 = expIn0.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all()) - .dup(); + .dup(); INDArray in1 = expIn1.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all()) - .dup(); + .dup(); INDArray out0 = expOut0.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all()) - .dup(); + .dup(); INDArray out1 = expOut1.get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all()) - .dup(); + .dup(); list.add(new MultiDataSet(new INDArray[] {in0, in1}, new INDArray[] {out0, out1})); } } @@ -328,7 +338,9 @@ public class MultiDataSetTest extends BaseNd4jTest { } @Test - public void testMergingTimeSeriesWithMasking() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergingTimeSeriesWithMasking(Nd4jBackend backend) { //Mask arrays, and different lengths int tsLengthIn0 = 8; @@ -387,27 +399,27 @@ public class MultiDataSetTest extends BaseNd4jTest { } expectedIn0.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all(), - NDArrayIndex.interval(0, thisRowIn0Length)}, in0); + NDArrayIndex.interval(0, thisRowIn0Length)}, in0); expectedIn1.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all(), - NDArrayIndex.interval(0, thisRowIn1Length)}, in1); + NDArrayIndex.interval(0, thisRowIn1Length)}, in1); expectedOut0.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all(), - NDArrayIndex.interval(0, thisRowOut0Length)}, out0); + NDArrayIndex.interval(0, thisRowOut0Length)}, out0); expectedOut1.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all(), - NDArrayIndex.interval(0, thisRowOut1Length)}, out1); + NDArrayIndex.interval(0, thisRowOut1Length)}, out1); expectedMaskIn0.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowIn0Length)}, - Nd4j.ones(1, thisRowIn0Length)); + Nd4j.ones(1, thisRowIn0Length)); expectedMaskIn1.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowIn1Length)}, - maskIn1); + maskIn1); expectedMaskOut0.put( - new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowOut0Length)}, - Nd4j.ones(1, thisRowOut0Length)); + new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowOut0Length)}, + Nd4j.ones(1, thisRowOut0Length)); expectedMaskOut1.put( - new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowOut1Length)}, - maskOut1); + new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowOut1Length)}, + maskOut1); list.add(new MultiDataSet(new INDArray[] {in0, in1}, new INDArray[] {out0, out1}, - new INDArray[] {maskIn0, maskIn1}, new INDArray[] {maskOut0, maskOut1})); + new INDArray[] {maskIn0, maskIn1}, new INDArray[] {maskOut0, maskOut1})); } MultiDataSet merged = MultiDataSet.merge(list); @@ -429,7 +441,9 @@ public class MultiDataSetTest extends BaseNd4jTest { } @Test - public void testMergingWithPerOutputMasking() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMergingWithPerOutputMasking(Nd4jBackend backend) { //Test 2d mask merging, 2d data //features @@ -478,14 +492,14 @@ public class MultiDataSetTest extends BaseNd4jTest { INDArray expLabels3d = Nd4j.create(3, 3, 4); expLabels3d.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)}, - l3d1); + l3d1); expLabels3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(), - NDArrayIndex.interval(0, 3)}, l3d2); + NDArrayIndex.interval(0, 3)}, l3d2); INDArray expLM3d = Nd4j.create(3, 3, 4); expLM3d.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)}, - lm3d1); + lm3d1); expLM3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(), - NDArrayIndex.interval(0, 3)}, lm3d2); + NDArrayIndex.interval(0, 3)}, lm3d2); MultiDataSet merged3d = MultiDataSet.merge(Arrays.asList(mds3d1, mds3d2)); @@ -502,7 +516,9 @@ public class MultiDataSetTest extends BaseNd4jTest { } @Test - public void testSplit() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSplit(Nd4jBackend backend) { INDArray[] features = new INDArray[3]; features[0] = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape('c', 3, 10); @@ -537,9 +553,9 @@ public class MultiDataSetTest extends BaseNd4jTest { assertEquals(features[0].get(NDArrayIndex.interval(i,i,true), NDArrayIndex.all()), m.getFeatures(0)); assertEquals(features[1].get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all()), - m.getFeatures(1)); + m.getFeatures(1)); assertEquals(features[2].get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.all()), m.getFeatures(2)); + NDArrayIndex.all()), m.getFeatures(2)); assertEquals(2, m.getLabels(0).rank()); assertEquals(3, m.getLabels(1).rank()); @@ -551,9 +567,9 @@ public class MultiDataSetTest extends BaseNd4jTest { assertEquals(labels[0].get(NDArrayIndex.interval(i,i,true), NDArrayIndex.all()), m.getLabels(0)); assertEquals(labels[1].get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all()), - m.getLabels(1)); + m.getLabels(1)); assertEquals(labels[2].get(NDArrayIndex.interval(i, i, true), NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.all()), m.getLabels(2)); + NDArrayIndex.all()), m.getLabels(2)); assertNull(m.getFeaturesMaskArray(0)); assertEquals(fMask[1].get(NDArrayIndex.interval(i,i,true), NDArrayIndex.all()), m.getFeaturesMaskArray(1)); @@ -564,7 +580,9 @@ public class MultiDataSetTest extends BaseNd4jTest { } @Test - public void testToString() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToString(Nd4jBackend backend) { //Mask arrays, and different lengths int tsLengthIn0 = 8; @@ -623,27 +641,27 @@ public class MultiDataSetTest extends BaseNd4jTest { } expectedIn0.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all(), - NDArrayIndex.interval(0, thisRowIn0Length)}, in0); + NDArrayIndex.interval(0, thisRowIn0Length)}, in0); expectedIn1.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all(), - NDArrayIndex.interval(0, thisRowIn1Length)}, in1); + NDArrayIndex.interval(0, thisRowIn1Length)}, in1); expectedOut0.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all(), - NDArrayIndex.interval(0, thisRowOut0Length)}, out0); + NDArrayIndex.interval(0, thisRowOut0Length)}, out0); expectedOut1.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.all(), - NDArrayIndex.interval(0, thisRowOut1Length)}, out1); + NDArrayIndex.interval(0, thisRowOut1Length)}, out1); expectedMaskIn0.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowIn0Length)}, - Nd4j.ones(1, thisRowIn0Length)); + Nd4j.ones(1, thisRowIn0Length)); expectedMaskIn1.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowIn1Length)}, - maskIn1); + maskIn1); expectedMaskOut0.put( - new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowOut0Length)}, - Nd4j.ones(1, thisRowOut0Length)); + new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowOut0Length)}, + Nd4j.ones(1, thisRowOut0Length)); expectedMaskOut1.put( - new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowOut1Length)}, - maskOut1); + new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, thisRowOut1Length)}, + maskOut1); list.add(new MultiDataSet(new INDArray[] {in0, in1}, new INDArray[] {out0, out1}, - new INDArray[] {maskIn0, maskIn1}, new INDArray[] {maskOut0, maskOut1})); + new INDArray[] {maskIn0, maskIn1}, new INDArray[] {maskOut0, maskOut1})); } MultiDataSet merged = MultiDataSet.merge(list); @@ -651,6 +669,8 @@ public class MultiDataSetTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void multiDataSetSaveLoadTest() throws IOException { int max = 3; @@ -706,7 +726,9 @@ public class MultiDataSetTest extends BaseNd4jTest { } @Test - public void testCnnMergeFeatureMasks() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCnnMergeFeatureMasks(Nd4jBackend backend) { //Tests merging of different CNN masks: [mb,1,h,1], [mb,1,1,w], [mb,1,h,w] for( int t=0; t<3; t++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerHybridTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerHybridTest.java index 020b524a7..58bc669de 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerHybridTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerHybridTest.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerHybrid; import org.nd4j.linalg.factory.Nd4j; @@ -32,8 +33,8 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) -public class MultiNormalizerHybridTest extends BaseNd4jTest { + +public class MultiNormalizerHybridTest extends BaseNd4jTestWithBackends { private MultiNormalizerHybrid SUT; private MultiDataSet data; private MultiDataSet dataCopy; @@ -42,19 +43,18 @@ public class MultiNormalizerHybridTest extends BaseNd4jTest { public void setUp() { SUT = new MultiNormalizerHybrid(); data = new MultiDataSet( - new INDArray[] {Nd4j.create(new float[][] {{1, 2}, {3, 4}}), - Nd4j.create(new float[][] {{3, 4}, {5, 6}}),}, - new INDArray[] {Nd4j.create(new float[][] {{10, 11}, {12, 13}}), - Nd4j.create(new float[][] {{14, 15}, {16, 17}}),}); + new INDArray[] {Nd4j.create(new float[][] {{1, 2}, {3, 4}}), + Nd4j.create(new float[][] {{3, 4}, {5, 6}}),}, + new INDArray[] {Nd4j.create(new float[][] {{10, 11}, {12, 13}}), + Nd4j.create(new float[][] {{14, 15}, {16, 17}}),}); dataCopy = data.copy(); } - public MultiNormalizerHybridTest(Nd4jBackend backend) { - super(backend); - } @Test - public void testNoNormalizationByDefault() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNoNormalizationByDefault(Nd4jBackend backend) { SUT.fit(data); SUT.preProcess(data); assertEquals(dataCopy, data); @@ -64,15 +64,17 @@ public class MultiNormalizerHybridTest extends BaseNd4jTest { } @Test - public void testGlobalNormalization() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGlobalNormalization(Nd4jBackend backend) { SUT.standardizeAllInputs().minMaxScaleAllOutputs(-10, 10).fit(data); SUT.preProcess(data); MultiDataSet expected = new MultiDataSet( - new INDArray[] {Nd4j.create(new float[][] {{-1, -1}, {1, 1}}), - Nd4j.create(new float[][] {{-1, -1}, {1, 1}}),}, - new INDArray[] {Nd4j.create(new float[][] {{-10, -10}, {10, 10}}), - Nd4j.create(new float[][] {{-10, -10}, {10, 10}}),}); + new INDArray[] {Nd4j.create(new float[][] {{-1, -1}, {1, 1}}), + Nd4j.create(new float[][] {{-1, -1}, {1, 1}}),}, + new INDArray[] {Nd4j.create(new float[][] {{-10, -10}, {10, 10}}), + Nd4j.create(new float[][] {{-10, -10}, {10, 10}}),}); assertEquals(expected, data); @@ -81,15 +83,17 @@ public class MultiNormalizerHybridTest extends BaseNd4jTest { } @Test - public void testSpecificInputOutputNormalization() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSpecificInputOutputNormalization(Nd4jBackend backend) { SUT.minMaxScaleAllInputs().standardizeInput(1).standardizeOutput(0).fit(data); SUT.preProcess(data); MultiDataSet expected = new MultiDataSet( - new INDArray[] {Nd4j.create(new float[][] {{0, 0}, {1, 1}}), - Nd4j.create(new float[][] {{-1, -1}, {1, 1}}),}, - new INDArray[] {Nd4j.create(new float[][] {{-1, -1}, {1, 1}}), - Nd4j.create(new float[][] {{14, 15}, {16, 17}}),}); + new INDArray[] {Nd4j.create(new float[][] {{0, 0}, {1, 1}}), + Nd4j.create(new float[][] {{-1, -1}, {1, 1}}),}, + new INDArray[] {Nd4j.create(new float[][] {{-1, -1}, {1, 1}}), + Nd4j.create(new float[][] {{14, 15}, {16, 17}}),}); assertEquals(expected, data); @@ -98,22 +102,24 @@ public class MultiNormalizerHybridTest extends BaseNd4jTest { } @Test - public void testMasking() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMasking(Nd4jBackend backend) { MultiDataSet timeSeries = new MultiDataSet( - new INDArray[] {Nd4j.create(new float[] {1, 2, 3, 4, 5, 0, 7, 0}).reshape(2, 2, 2),}, - new INDArray[] {Nd4j.create(new float[] {0, 20, 0, 40, 50, 60, 70, 80}).reshape(2, 2, 2)}, - new INDArray[] {Nd4j.create(new float[][] {{1, 1}, {1, 0}})}, - new INDArray[] {Nd4j.create(new float[][] {{0, 1}, {1, 1}})}); + new INDArray[] {Nd4j.create(new float[] {1, 2, 3, 4, 5, 0, 7, 0}).reshape(2, 2, 2),}, + new INDArray[] {Nd4j.create(new float[] {0, 20, 0, 40, 50, 60, 70, 80}).reshape(2, 2, 2)}, + new INDArray[] {Nd4j.create(new float[][] {{1, 1}, {1, 0}})}, + new INDArray[] {Nd4j.create(new float[][] {{0, 1}, {1, 1}})}); MultiDataSet timeSeriesCopy = timeSeries.copy(); SUT.minMaxScaleAllInputs(-10, 10).minMaxScaleAllOutputs(-10, 10).fit(timeSeries); SUT.preProcess(timeSeries); MultiDataSet expected = new MultiDataSet( - new INDArray[] {Nd4j.create(new float[] {-10, -5, -10, -5, 10, 0, 10, 0}).reshape(2, 2, 2),}, - new INDArray[] {Nd4j.create(new float[] {0, -10, 0, -10, 5, 10, 5, 10}).reshape(2, 2, 2),}, - new INDArray[] {Nd4j.create(new float[][] {{1, 1}, {1, 0}})}, - new INDArray[] {Nd4j.create(new float[][] {{0, 1}, {1, 1}})}); + new INDArray[] {Nd4j.create(new float[] {-10, -5, -10, -5, 10, 0, 10, 0}).reshape(2, 2, 2),}, + new INDArray[] {Nd4j.create(new float[] {0, -10, 0, -10, 5, 10, 5, 10}).reshape(2, 2, 2),}, + new INDArray[] {Nd4j.create(new float[][] {{1, 1}, {1, 0}})}, + new INDArray[] {Nd4j.create(new float[][] {{0, 1}, {1, 1}})}); assertEquals(expected, timeSeries); @@ -123,7 +129,9 @@ public class MultiNormalizerHybridTest extends BaseNd4jTest { } @Test - public void testDataSetWithoutLabels() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDataSetWithoutLabels(Nd4jBackend backend) { SUT.standardizeAllInputs().standardizeAllOutputs().fit(data); data.setLabels(null); @@ -133,7 +141,9 @@ public class MultiNormalizerHybridTest extends BaseNd4jTest { } @Test - public void testDataSetWithoutFeatures() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDataSetWithoutFeatures(Nd4jBackend backend) { SUT.standardizeAllInputs().standardizeAllOutputs().fit(data); data.setFeatures(null); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerMinMaxScalerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerMinMaxScalerTest.java index da87004a5..48c71d4c3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerMinMaxScalerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerMinMaxScalerTest.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.TestMultiDataSetIterator; @@ -35,8 +36,8 @@ import org.nd4j.linalg.ops.transforms.Transforms; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class MultiNormalizerMinMaxScalerTest extends BaseNd4jTest { + +public class MultiNormalizerMinMaxScalerTest extends BaseNd4jTestWithBackends { private static final double TOLERANCE_PERC = 0.01; // 0.01% of correct value private static final int INPUT1_SCALE = 1, INPUT2_SCALE = 2, OUTPUT1_SCALE = 3, OUTPUT2_SCALE = 4; @@ -66,25 +67,28 @@ public class MultiNormalizerMinMaxScalerTest extends BaseNd4jTest { naturalMax = nSamples; } - public MultiNormalizerMinMaxScalerTest(Nd4jBackend backend) { - super(backend); - } @Test - public void testMultipleInputsAndOutputsWithDataSet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultipleInputsAndOutputsWithDataSet(Nd4jBackend backend) { SUT.fit(data); assertExpectedMinMax(); } @Test - public void testMultipleInputsAndOutputsWithIterator() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultipleInputsAndOutputsWithIterator(Nd4jBackend backend) { MultiDataSetIterator iter = new TestMultiDataSetIterator(1, data); SUT.fit(iter); assertExpectedMinMax(); } @Test - public void testRevertFeaturesINDArray() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRevertFeaturesINDArray(Nd4jBackend backend) { SUT.fit(data); MultiDataSet transformed = data.copy(); @@ -100,7 +104,9 @@ public class MultiNormalizerMinMaxScalerTest extends BaseNd4jTest { } @Test - public void testRevertLabelsINDArray() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRevertLabelsINDArray(Nd4jBackend backend) { SUT.fit(data); MultiDataSet transformed = data.copy(); @@ -116,7 +122,9 @@ public class MultiNormalizerMinMaxScalerTest extends BaseNd4jTest { } @Test - public void testRevertMultiDataSet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRevertMultiDataSet(Nd4jBackend backend) { SUT.fit(data); MultiDataSet transformed = data.copy(); @@ -132,13 +140,15 @@ public class MultiNormalizerMinMaxScalerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testFullyMaskedData() { MultiDataSetIterator iter = new TestMultiDataSetIterator(1, - new MultiDataSet(new INDArray[] {Nd4j.create(new float[] {1}).reshape(1, 1, 1)}, - new INDArray[] {Nd4j.create(new float[] {2}).reshape(1, 1, 1)}), - new MultiDataSet(new INDArray[] {Nd4j.create(new float[] {2}).reshape(1, 1, 1)}, - new INDArray[] {Nd4j.create(new float[] {4}).reshape(1, 1, 1)}, null, - new INDArray[] {Nd4j.create(new float[] {0}).reshape(1, 1)})); + new MultiDataSet(new INDArray[] {Nd4j.create(new float[] {1}).reshape(1, 1, 1)}, + new INDArray[] {Nd4j.create(new float[] {2}).reshape(1, 1, 1)}), + new MultiDataSet(new INDArray[] {Nd4j.create(new float[] {2}).reshape(1, 1, 1)}, + new INDArray[] {Nd4j.create(new float[] {4}).reshape(1, 1, 1)}, null, + new INDArray[] {Nd4j.create(new float[] {0}).reshape(1, 1)})); SUT.fit(iter); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerStandardizeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerStandardizeTest.java index 899c96b46..8f3a40d18 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerStandardizeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MultiNormalizerStandardizeTest.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.TestMultiDataSetIterator; @@ -35,8 +36,8 @@ import org.nd4j.linalg.ops.transforms.Transforms; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class MultiNormalizerStandardizeTest extends BaseNd4jTest { + +public class MultiNormalizerStandardizeTest extends BaseNd4jTestWithBackends { private static final double TOLERANCE_PERC = 0.01; // 0.01% of correct value private static final int INPUT1_SCALE = 1, INPUT2_SCALE = 2, OUTPUT1_SCALE = 3, OUTPUT2_SCALE = 4; @@ -65,25 +66,28 @@ public class MultiNormalizerStandardizeTest extends BaseNd4jTest { stdNaturalNums = Math.sqrt((nSamples * nSamples - 1) / 12.0); } - public MultiNormalizerStandardizeTest(Nd4jBackend backend) { - super(backend); - } @Test - public void testMultipleInputsAndOutputsWithDataSet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultipleInputsAndOutputsWithDataSet(Nd4jBackend backend) { SUT.fit(data); assertExpectedMeanStd(); } @Test - public void testMultipleInputsAndOutputsWithIterator() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultipleInputsAndOutputsWithIterator(Nd4jBackend backend) { MultiDataSetIterator iter = new TestMultiDataSetIterator(1, data); SUT.fit(iter); assertExpectedMeanStd(); } @Test - public void testRevertFeaturesINDArray() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRevertFeaturesINDArray(Nd4jBackend backend) { SUT.fit(data); MultiDataSet transformed = data.copy(); @@ -99,7 +103,9 @@ public class MultiNormalizerStandardizeTest extends BaseNd4jTest { } @Test - public void testRevertLabelsINDArray() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRevertLabelsINDArray(Nd4jBackend backend) { SUT.fit(data); MultiDataSet transformed = data.copy(); @@ -115,7 +121,9 @@ public class MultiNormalizerStandardizeTest extends BaseNd4jTest { } @Test - public void testRevertMultiDataSet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRevertMultiDataSet(Nd4jBackend backend) { SUT.fit(data); MultiDataSet transformed = data.copy(); @@ -131,13 +139,15 @@ public class MultiNormalizerStandardizeTest extends BaseNd4jTest { } @Test - public void testFullyMaskedData() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFullyMaskedData(Nd4jBackend backend) { MultiDataSetIterator iter = new TestMultiDataSetIterator(1, - new MultiDataSet(new INDArray[] {Nd4j.create(new float[] {1}).reshape(1, 1, 1)}, - new INDArray[] {Nd4j.create(new float[] {2}).reshape(1, 1, 1)}), - new MultiDataSet(new INDArray[] {Nd4j.create(new float[] {2}).reshape(1, 1, 1)}, - new INDArray[] {Nd4j.create(new float[] {4}).reshape(1, 1, 1)}, null, - new INDArray[] {Nd4j.create(new float[] {0}).reshape(1, 1)})); + new MultiDataSet(new INDArray[] {Nd4j.create(new float[] {1}).reshape(1, 1, 1)}, + new INDArray[] {Nd4j.create(new float[] {2}).reshape(1, 1, 1)}), + new MultiDataSet(new INDArray[] {Nd4j.create(new float[] {2}).reshape(1, 1, 1)}, + new INDArray[] {Nd4j.create(new float[] {4}).reshape(1, 1, 1)}, null, + new INDArray[] {Nd4j.create(new float[] {0}).reshape(1, 1)})); SUT.fit(iter); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerMinMaxScalerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerMinMaxScalerTest.java index 5e7cff650..36bf8d76c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerMinMaxScalerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerMinMaxScalerTest.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.TestDataSetIterator; @@ -37,15 +38,14 @@ import org.nd4j.linalg.ops.transforms.Transforms; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class NormalizerMinMaxScalerTest extends BaseNd4jTest { - public NormalizerMinMaxScalerTest(Nd4jBackend backend) { - super(backend); - } +public class NormalizerMinMaxScalerTest extends BaseNd4jTestWithBackends { + @Test - public void testBruteForce() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBruteForce(Nd4jBackend backend) { //X_std = (X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0)) //X_scaled = X_std * (max - min) + min // Dataset features are scaled consecutive natural numbers @@ -98,7 +98,9 @@ public class NormalizerMinMaxScalerTest extends BaseNd4jTest { } @Test - public void testRevert() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRevert(Nd4jBackend backend) { double tolerancePerc = 1; // 1% of correct value int nSamples = 500; int nFeatures = 3; @@ -115,7 +117,7 @@ public class NormalizerMinMaxScalerTest extends BaseNd4jTest { myNormalizer.transform(transformed); myNormalizer.revert(transformed); INDArray delta = Transforms.abs(transformed.getFeatures().sub(sampleDataSet.getFeatures())) - .div(sampleDataSet.getFeatures()); + .div(sampleDataSet.getFeatures()); double maxdeltaPerc = delta.max(0, 1).mul(100).getDouble(0); System.out.println("Delta: " + maxdeltaPerc); assertTrue(maxdeltaPerc < tolerancePerc); @@ -123,7 +125,9 @@ public class NormalizerMinMaxScalerTest extends BaseNd4jTest { } @Test - public void testGivenMaxMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGivenMaxMin(Nd4jBackend backend) { double tolerancePerc = 1; // 1% of correct value int nSamples = 500; int nFeatures = 3; @@ -143,14 +147,16 @@ public class NormalizerMinMaxScalerTest extends BaseNd4jTest { myNormalizer.revert(transformed); INDArray delta = Transforms.abs(transformed.getFeatures().sub(sampleDataSet.getFeatures())) - .div(sampleDataSet.getFeatures()); + .div(sampleDataSet.getFeatures()); double maxdeltaPerc = delta.max(0, 1).mul(100).getDouble(0); System.out.println("Delta: " + maxdeltaPerc); assertTrue(maxdeltaPerc < tolerancePerc); } @Test - public void testGivenMaxMinConstant() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGivenMaxMinConstant(Nd4jBackend backend) { double tolerancePerc = 1; // 1% of correct value int nSamples = 500; int nFeatures = 3; @@ -175,7 +181,9 @@ public class NormalizerMinMaxScalerTest extends BaseNd4jTest { } @Test - public void testConstant() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConstant(Nd4jBackend backend) { double tolerancePerc = 0.01; // 0.01% of correct value int nSamples = 500; int nFeatures = 3; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerSerializerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerSerializerTest.java index 3c88e2256..b095f419b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerSerializerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerSerializerTest.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.dataset; import lombok.Getter; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.dataset.api.preprocessor.AbstractDataSetNormalizer; import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; import org.nd4j.linalg.dataset.api.preprocessor.MinMaxStrategy; @@ -41,7 +42,6 @@ import org.nd4j.linalg.dataset.api.preprocessor.stats.DistributionStats; import org.nd4j.linalg.dataset.api.preprocessor.stats.MinMaxStats; import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.factory.Nd4jBackend; import java.io.*; import java.util.HashMap; @@ -54,14 +54,11 @@ import static org.junit.jupiter.api.Assertions.assertThrows; /** * @author Ede Meijer */ -@RunWith(Parameterized.class) -public class NormalizerSerializerTest extends BaseNd4jTest { + +public class NormalizerSerializerTest extends BaseNd4jTestWithBackends { private File tmpFile; private NormalizerSerializer SUT; - public NormalizerSerializerTest(Nd4jBackend backend) { - super(backend); - } @BeforeEach public void setUp() throws IOException { @@ -72,6 +69,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testImagePreProcessingScaler() throws Exception { ImagePreProcessingScaler imagePreProcessingScaler = new ImagePreProcessingScaler(0,1); SUT.write(imagePreProcessingScaler,tmpFile); @@ -81,6 +80,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testNormalizerStandardizeNotFitLabels() throws Exception { NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)); @@ -92,6 +93,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testNormalizerStandardizeFitLabels() throws Exception { NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1), Nd4j.create(new double[] {4.5, 5.5}).reshape(1, -1), @@ -105,6 +108,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testNormalizerMinMaxScalerNotFitLabels() throws Exception { NormalizerMinMaxScaler original = new NormalizerMinMaxScaler(0.1, 0.9); original.setFeatureStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)); @@ -116,6 +121,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testNormalizerMinMaxScalerFitLabels() throws Exception { NormalizerMinMaxScaler original = new NormalizerMinMaxScaler(0.1, 0.9); original.setFeatureStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})); @@ -129,6 +136,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMultiNormalizerStandardizeNotFitLabels() throws Exception { MultiNormalizerStandardize original = new MultiNormalizerStandardize(); original.setFeatureStats(asList( @@ -144,6 +153,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMultiNormalizerStandardizeFitLabels() throws Exception { MultiNormalizerStandardize original = new MultiNormalizerStandardize(); original.setFeatureStats(asList( @@ -166,6 +177,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMultiNormalizerMinMaxScalerNotFitLabels() throws Exception { MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9); original.setFeatureStats(asList( @@ -180,6 +193,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMultiNormalizerMinMaxScalerFitLabels() throws Exception { MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9); original.setFeatureStats(asList( @@ -200,6 +215,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMultiNormalizerHybridEmpty() throws Exception { MultiNormalizerHybrid original = new MultiNormalizerHybrid(); original.setInputStats(new HashMap()); @@ -212,6 +229,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMultiNormalizerHybridGlobalStats() throws Exception { MultiNormalizerHybrid original = new MultiNormalizerHybrid().minMaxScaleAllInputs().standardizeAllOutputs(); @@ -233,6 +252,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMultiNormalizerHybridGlobalAndSpecificStats() throws Exception { MultiNormalizerHybrid original = new MultiNormalizerHybrid().standardizeAllInputs().minMaxScaleInput(0, -5, 5) .minMaxScaleAllOutputs(-10, 10).standardizeOutput(1); @@ -263,6 +284,8 @@ public class NormalizerSerializerTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCustomNormalizer() throws Exception { MyNormalizer original = new MyNormalizer(42); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeLabelsTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeLabelsTest.java index 1725b9ebe..4b8b36e6d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeLabelsTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeLabelsTest.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.TestDataSetIterator; @@ -35,14 +36,13 @@ import org.nd4j.linalg.ops.transforms.Transforms; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -@RunWith(Parameterized.class) -public class NormalizerStandardizeLabelsTest extends BaseNd4jTest { - public NormalizerStandardizeLabelsTest(Nd4jBackend backend) { - super(backend); - } + +public class NormalizerStandardizeLabelsTest extends BaseNd4jTestWithBackends { @Test - public void testBruteForce() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBruteForce(Nd4jBackend backend) { /* This test creates a dataset where feature values are multiples of consecutive natural numbers The obtained values are compared to the theoretical mean and std dev */ @@ -59,11 +59,11 @@ public class NormalizerStandardizeLabelsTest extends BaseNd4jTest { double meanNaturalNums = (nSamples + 1) / 2.0; INDArray theoreticalMean = - Nd4j.create(new double[] {meanNaturalNums * x, meanNaturalNums * y, meanNaturalNums * z}).reshape(1, -1).castTo(Nd4j.defaultFloatingPointType()); + Nd4j.create(new double[] {meanNaturalNums * x, meanNaturalNums * y, meanNaturalNums * z}).reshape(1, -1).castTo(Nd4j.defaultFloatingPointType()); INDArray theoreticallabelMean = theoreticalMean.dup().getColumns(0); double stdNaturalNums = Math.sqrt((nSamples * nSamples - 1) / 12.0); INDArray theoreticalStd = - Nd4j.create(new double[] {stdNaturalNums * x, stdNaturalNums * y, stdNaturalNums * z}).reshape(1, -1).castTo(Nd4j.defaultFloatingPointType()); + Nd4j.create(new double[] {stdNaturalNums * x, stdNaturalNums * y, stdNaturalNums * z}).reshape(1, -1).castTo(Nd4j.defaultFloatingPointType()); INDArray theoreticallabelStd = theoreticalStd.dup().getColumns(0); NormalizerStandardize myNormalizer = new NormalizerStandardize(); @@ -81,7 +81,7 @@ public class NormalizerStandardizeLabelsTest extends BaseNd4jTest { INDArray stdDelta = Transforms.abs(theoreticalStd.sub(myNormalizer.getStd())); INDArray stdDeltaPerc = stdDelta.div(theoreticalStd).mul(100); INDArray stdlabelDeltaPerc = - Transforms.abs(theoreticallabelStd.sub(myNormalizer.getLabelStd())).div(theoreticallabelStd); + Transforms.abs(theoreticallabelStd.sub(myNormalizer.getLabelStd())).div(theoreticallabelStd); double maxStdDeltaPerc = stdDeltaPerc.max(1).mul(100).getDouble(0); double maxlabelStdDeltaPerc = stdlabelDeltaPerc.max(1).getDouble(0); assertTrue(maxStdDeltaPerc < tolerancePerc); @@ -106,7 +106,9 @@ public class NormalizerStandardizeLabelsTest extends BaseNd4jTest { } @Test - public void testTransform() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTransform(Nd4jBackend backend) { /*Random dataset is generated such that AX + B where X is from a normal distribution with mean 0 and std 1 The mean of above will be B and std A diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeTest.java index 25cd555f3..cf7d253ba 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerStandardizeTest.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -35,11 +36,8 @@ import org.nd4j.linalg.ops.transforms.Transforms; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class NormalizerStandardizeTest extends BaseNd4jTest { - public NormalizerStandardizeTest(Nd4jBackend backend) { - super(backend); - } + +public class NormalizerStandardizeTest extends BaseNd4jTestWithBackends { @Override public long getTimeoutMilliseconds() { @@ -47,7 +45,9 @@ public class NormalizerStandardizeTest extends BaseNd4jTest { } @Test - public void testBruteForce() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBruteForce(Nd4jBackend backend) { /* This test creates a dataset where feature values are multiples of consecutive natural numbers The obtained values are compared to the theoretical mean and std dev */ @@ -64,10 +64,10 @@ public class NormalizerStandardizeTest extends BaseNd4jTest { double meanNaturalNums = (nSamples + 1) / 2.0; INDArray theoreticalMean = - Nd4j.create(new double[] {meanNaturalNums * x, meanNaturalNums * y, meanNaturalNums * z}).reshape(1, -1); + Nd4j.create(new double[] {meanNaturalNums * x, meanNaturalNums * y, meanNaturalNums * z}).reshape(1, -1); double stdNaturalNums = Math.sqrt((nSamples * nSamples - 1) / 12.0); INDArray theoreticalStd = - Nd4j.create(new double[] {stdNaturalNums * x, stdNaturalNums * y, stdNaturalNums * z}).reshape(1, -1); + Nd4j.create(new double[] {stdNaturalNums * x, stdNaturalNums * y, stdNaturalNums * z}).reshape(1, -1); NormalizerStandardize myNormalizer = new NormalizerStandardize(); myNormalizer.fit(sampleDataSet); @@ -100,7 +100,9 @@ public class NormalizerStandardizeTest extends BaseNd4jTest { } @Test - public void testTransform() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTransform(Nd4jBackend backend) { /*Random dataset is generated such that AX + B where X is from a normal distribution with mean 0 and std 1 The mean of above will be B and std A @@ -172,7 +174,9 @@ public class NormalizerStandardizeTest extends BaseNd4jTest { } @Test - public void testDifferentBatchSizes() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDifferentBatchSizes(Nd4jBackend backend) { // Create 6x1 matrix of the numbers 1 through 6 INDArray values = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1).transpose(); DataSet dataSet = new DataSet(values, values); @@ -206,7 +210,9 @@ public class NormalizerStandardizeTest extends BaseNd4jTest { } @Test - public void testUnderOverflow() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUnderOverflow(Nd4jBackend backend) { // This dataset will be basically constant with a small std deviation // And the constant is large. Checking if algorithm can handle double tolerancePerc = 1; //Within 1 % @@ -239,7 +245,9 @@ public class NormalizerStandardizeTest extends BaseNd4jTest { } @Test - public void testRevert() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRevert(Nd4jBackend backend) { double tolerancePerc = 0.01; // 0.01% of correct value int nSamples = 500; int nFeatures = 3; @@ -256,13 +264,15 @@ public class NormalizerStandardizeTest extends BaseNd4jTest { myNormalizer.revert(transformed); //System.out.println(transformed.getFeatures()); INDArray delta = Transforms.abs(transformed.getFeatures().sub(sampleDataSet.getFeatures())) - .div(sampleDataSet.getFeatures()); + .div(sampleDataSet.getFeatures()); double maxdeltaPerc = delta.max(0, 1).mul(100).getDouble(0); assertTrue(maxdeltaPerc < tolerancePerc); } @Test - public void testConstant() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConstant(Nd4jBackend backend) { double tolerancePerc = 10.0; // 10% of correct value int nSamples = 500; int nFeatures = 3; @@ -283,13 +293,13 @@ public class NormalizerStandardizeTest extends BaseNd4jTest { assertFalse(Double.isNaN(sampleDataSet.getFeatures().min(0, 1).getDouble(0))); //Checking to see if transformed values are close enough to zero assertEquals(Transforms.abs(sampleDataSet.getFeatures()).max(0, 1).getDouble(0), 0, - constant * tolerancePerc / 100.0); + constant * tolerancePerc / 100.0); myNormalizer.revert(sampleDataSet); //Checking if we gets nans, because std dev is zero assertFalse(Double.isNaN(sampleDataSet.getFeatures().min(0, 1).getDouble(0))); assertEquals(Transforms.abs(sampleDataSet.getFeatures().sub(featureSet)).min(0, 1).getDouble(0), 0, - constant * tolerancePerc / 100.0); + constant * tolerancePerc / 100.0); } public class genRandomDataSet { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java index 317b8c806..5f0c5a7a7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -49,12 +50,9 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -@RunWith(Parameterized.class) -public class NormalizerTests extends BaseNd4jTest { - public NormalizerTests(Nd4jBackend backend) { - super(backend); - } +public class NormalizerTests extends BaseNd4jTestWithBackends { + private NormalizerStandardize stdScaler; private NormalizerMinMaxScaler minMaxScaler; @@ -78,7 +76,9 @@ public class NormalizerTests extends BaseNd4jTest { } @Test - public void testPreProcessors() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPreProcessors(Nd4jBackend backend) { System.out.println("Running iterator vs non-iterator std scaler.."); double d1 = testItervsDataset(stdScaler); assertTrue( d1 < thresholdPerc,d1 + " < " + thresholdPerc); @@ -111,17 +111,19 @@ public class NormalizerTests extends BaseNd4jTest { @Test - public void testMasking() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMasking(Nd4jBackend backend) { Nd4j.getRandom().setSeed(235); DataNormalization[] normalizers = - new DataNormalization[] {new NormalizerMinMaxScaler(), new NormalizerStandardize()}; + new DataNormalization[] {new NormalizerMinMaxScaler(), new NormalizerStandardize()}; DataNormalization[] normalizersNoMask = - new DataNormalization[] {new NormalizerMinMaxScaler(), new NormalizerStandardize()}; + new DataNormalization[] {new NormalizerMinMaxScaler(), new NormalizerStandardize()}; DataNormalization[] normalizersByRow = - new DataNormalization[] {new NormalizerMinMaxScaler(), new NormalizerStandardize()}; + new DataNormalization[] {new NormalizerMinMaxScaler(), new NormalizerStandardize()}; for (int i = 0; i < normalizers.length; i++) { @@ -139,8 +141,8 @@ public class NormalizerTests extends BaseNd4jTest { INDArray arrPt1 = arr.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.all(), NDArrayIndex.all()).dup(); INDArray arrPt2 = - arr.get(NDArrayIndex.interval(1, 1, true), NDArrayIndex.all(), NDArrayIndex.interval(0, 3)) - .dup(); + arr.get(NDArrayIndex.interval(1, 1, true), NDArrayIndex.all(), NDArrayIndex.interval(0, 3)) + .dup(); INDArray mask = Nd4j.create(new double[][] {{1, 1, 1, 1, 1}, {1, 1, 1, 0, 0}}).castTo(Nd4j.defaultFloatingPointType()); @@ -161,14 +163,14 @@ public class NormalizerTests extends BaseNd4jTest { List toFitRows = new ArrayList<>(); for (int j = 0; j < 5; j++) { INDArray row = arr.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(j, j, true)) - .transpose(); + .transpose(); assertTrue(row.isRowVector()); toFitRows.add(new DataSet(row, row)); } for (int j = 0; j < 3; j++) { INDArray row = arr.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.interval(j, j, true)) - .transpose(); + .transpose(); assertTrue(row.isRowVector()); toFitRows.add(new DataSet(row, row)); } @@ -189,11 +191,11 @@ public class NormalizerTests extends BaseNd4jTest { //Second: ensure time steps post normalization (and post revert) are 0.0 INDArray shouldBe0_1 = ds.getFeatures().get(NDArrayIndex.point(1), NDArrayIndex.all(), - NDArrayIndex.interval(3, 5)); + NDArrayIndex.interval(3, 5)); INDArray shouldBe0_2 = dsCopy1.getFeatures().get(NDArrayIndex.point(1), NDArrayIndex.all(), - NDArrayIndex.interval(3, 5)); + NDArrayIndex.interval(3, 5)); INDArray shouldBe0_3 = dsCopy2.getFeatures().get(NDArrayIndex.point(1), NDArrayIndex.all(), - NDArrayIndex.interval(3, 5)); + NDArrayIndex.interval(3, 5)); INDArray zeros = Nd4j.zeros(shouldBe0_1.shape()); @@ -212,11 +214,11 @@ public class NormalizerTests extends BaseNd4jTest { normFitSubset.revert(dsCopy1); normByRow.revert(dsCopy2); shouldBe0_1 = ds.getFeatures().get(NDArrayIndex.point(1), NDArrayIndex.all(), - NDArrayIndex.interval(3, 5)); + NDArrayIndex.interval(3, 5)); shouldBe0_2 = dsCopy1.getFeatures().get(NDArrayIndex.point(1), NDArrayIndex.all(), - NDArrayIndex.interval(3, 5)); + NDArrayIndex.interval(3, 5)); shouldBe0_3 = dsCopy2.getFeatures().get(NDArrayIndex.point(1), NDArrayIndex.all(), - NDArrayIndex.interval(3, 5)); + NDArrayIndex.interval(3, 5)); assertEquals(zeros, shouldBe0_1); assertEquals(zeros, shouldBe0_2); @@ -227,6 +229,8 @@ public class NormalizerTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testNormalizerToStringHashCode(){ //https://github.com/eclipse/deeplearning4j/issues/8565 @@ -262,6 +266,8 @@ public class NormalizerTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMultiNormalizerToStringHashCode(){ //https://github.com/eclipse/deeplearning4j/issues/8565 diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java index 0c0808a07..ef8e0fc77 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessor3D4DTest.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.dataset; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -43,15 +44,14 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.*; @Slf4j -@RunWith(Parameterized.class) -public class PreProcessor3D4DTest extends BaseNd4jTest { - public PreProcessor3D4DTest(Nd4jBackend backend) { - super(backend); - } +public class PreProcessor3D4DTest extends BaseNd4jTestWithBackends { + @Test - public void testBruteForce3d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBruteForce3d(Nd4jBackend backend) { NormalizerStandardize myNormalizer = new NormalizerStandardize(); NormalizerMinMaxScaler myMinMaxScaler = new NormalizerMinMaxScaler(); @@ -88,7 +88,9 @@ public class PreProcessor3D4DTest extends BaseNd4jTest { } @Test - public void testBruteForce3dMaskLabels() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBruteForce3dMaskLabels(Nd4jBackend backend) { NormalizerStandardize myNormalizer = new NormalizerStandardize(); myNormalizer.fitLabel(true); @@ -110,7 +112,7 @@ public class PreProcessor3D4DTest extends BaseNd4jTest { DataSet fullDataSetAA = fullDataSetA.copy(); //This should be the same datasets as above without a mask Construct3dDataSet fullDataSetNoMask = - new Construct3dDataSet(featureScale, timeStepsU + timeStepsV, samples, 1); + new Construct3dDataSet(featureScale, timeStepsU + timeStepsV, samples, 1); //preprocessors - label and feature values are the same myNormalizer.fit(fullDataSetA); @@ -146,93 +148,95 @@ public class PreProcessor3D4DTest extends BaseNd4jTest { } @Test - public void testStdX() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStdX(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {11.10, 22.20, 33.30, 44.40, 55.50, 66.60, 77.70, 88.80, 99.90, - 111.00, 122.10, 133.20, 144.30, 155.40, 166.50, 177.60, 188.70, 199.80, 210.90, 222.00, 233.10, - 244.20, 255.30, 266.40, 277.50, 288.60, 299.70, 310.80, 321.90, 333.00, 344.10, 355.20, 366.30, - 377.40, 388.50, 399.60, 410.70, 421.80, 432.90, 444.00, 455.10, 466.20, 477.30, 488.40, 499.50, - 510.60, 521.70, 532.80, 543.90, 555.00, 566.10, 577.20, 588.30, 599.40, 610.50, 621.60, 632.70, - 643.80, 654.90, 666.00, 677.10, 688.20, 699.30, 710.40, 721.50, 732.60, 743.70, 754.80, 765.90, - 777.00, 788.10, 799.20, 810.30, 821.40, 832.50, 843.60, 854.70, 865.80, 876.90, 888.00, 899.10, - 910.20, 921.30, 932.40, 943.50, 954.60, 965.70, 976.80, 987.90, 999.00, 1, 010.10, 1, 021.20, 1, - 032.30, 1, 043.40, 1, 054.50, 1, 065.60, 1, 076.70, 1, 087.80, 1, 098.90, 1, 110.00, 1, 121.10, - 1, 132.20, 1, 143.30, 1, 154.40, 1, 165.50, 1, 176.60, 1, 187.70, 1, 198.80, 1, 209.90, 1, - 221.00, 1, 232.10, 1, 243.20, 1, 254.30, 1, 265.40, 1, 276.50, 1, 287.60, 1, 298.70, 1, 309.80, - 1, 320.90, 1, 332.00, 1, 343.10, 1, 354.20, 1, 365.30, 1, 376.40, 1, 387.50, 1, 398.60, 1, - 409.70, 1, 420.80, 1, 431.90, 1, 443.00, 1, 454.10, 1, 465.20, 1, 476.30, 1, 487.40, 1, 498.50, - 1, 509.60, 1, 520.70, 1, 531.80, 1, 542.90, 1, 554.00, 1, 565.10, 1, 576.20, 1, 587.30, 1, - 598.40, 1, 609.50, 1, 620.60, 1, 631.70, 1, 642.80, 1, 653.90, 1, 665.00, 2.10, 4.20, 6.30, - 8.40, 10.50, 12.60, 14.70, 16.80, 18.90, 21.00, 23.10, 25.20, 27.30, 29.40, 31.50, 33.60, 35.70, - 37.80, 39.90, 42.00, 44.10, 46.20, 48.30, 50.40, 52.50, 54.60, 56.70, 58.80, 60.90, 63.00, - 65.10, 67.20, 69.30, 71.40, 73.50, 75.60, 77.70, 79.80, 81.90, 84.00, 86.10, 88.20, 90.30, - 92.40, 94.50, 96.60, 98.70, 100.80, 102.90, 105.00, 107.10, 109.20, 111.30, 113.40, 115.50, - 117.60, 119.70, 121.80, 123.90, 126.00, 128.10, 130.20, 132.30, 134.40, 136.50, 138.60, 140.70, - 142.80, 144.90, 147.00, 149.10, 151.20, 153.30, 155.40, 157.50, 159.60, 161.70, 163.80, 165.90, - 168.00, 170.10, 172.20, 174.30, 176.40, 178.50, 180.60, 182.70, 184.80, 186.90, 189.00, 191.10, - 193.20, 195.30, 197.40, 199.50, 201.60, 203.70, 205.80, 207.90, 210.00, 212.10, 214.20, 216.30, - 218.40, 220.50, 222.60, 224.70, 226.80, 228.90, 231.00, 233.10, 235.20, 237.30, 239.40, 241.50, - 243.60, 245.70, 247.80, 249.90, 252.00, 254.10, 256.20, 258.30, 260.40, 262.50, 264.60, 266.70, - 268.80, 270.90, 273.00, 275.10, 277.20, 279.30, 281.40, 283.50, 285.60, 287.70, 289.80, 291.90, - 294.00, 296.10, 298.20, 300.30, 302.40, 304.50, 306.60, 308.70, 310.80, 312.90, 315.00, 10.00, - 20.00, 30.00, 40.00, 50.00, 60.00, 70.00, 80.00, 90.00, 100.00, 110.00, 120.00, 130.00, 140.00, - 150.00, 160.00, 170.00, 180.00, 190.00, 200.00, 210.00, 220.00, 230.00, 240.00, 250.00, 260.00, - 270.00, 280.00, 290.00, 300.00, 310.00, 320.00, 330.00, 340.00, 350.00, 360.00, 370.00, 380.00, - 390.00, 400.00, 410.00, 420.00, 430.00, 440.00, 450.00, 460.00, 470.00, 480.00, 490.00, 500.00, - 510.00, 520.00, 530.00, 540.00, 550.00, 560.00, 570.00, 580.00, 590.00, 600.00, 610.00, 620.00, - 630.00, 640.00, 650.00, 660.00, 670.00, 680.00, 690.00, 700.00, 710.00, 720.00, 730.00, 740.00, - 750.00, 760.00, 770.00, 780.00, 790.00, 800.00, 810.00, 820.00, 830.00, 840.00, 850.00, 860.00, - 870.00, 880.00, 890.00, 900.00, 910.00, 920.00, 930.00, 940.00, 950.00, 960.00, 970.00, 980.00, - 990.00, 1, 000.00, 1, 010.00, 1, 020.00, 1, 030.00, 1, 040.00, 1, 050.00, 1, 060.00, 1, 070.00, - 1, 080.00, 1, 090.00, 1, 100.00, 1, 110.00, 1, 120.00, 1, 130.00, 1, 140.00, 1, 150.00, 1, - 160.00, 1, 170.00, 1, 180.00, 1, 190.00, 1, 200.00, 1, 210.00, 1, 220.00, 1, 230.00, 1, 240.00, - 1, 250.00, 1, 260.00, 1, 270.00, 1, 280.00, 1, 290.00, 1, 300.00, 1, 310.00, 1, 320.00, 1, - 330.00, 1, 340.00, 1, 350.00, 1, 360.00, 1, 370.00, 1, 380.00, 1, 390.00, 1, 400.00, 1, 410.00, - 1, 420.00, 1, 430.00, 1, 440.00, 1, 450.00, 1, 460.00, 1, 470.00, 1, 480.00, 1, 490.00, 1, - 500.00, 99.00, 198.00, 297.00, 396.00, 495.00, 594.00, 693.00, 792.00, 891.00, 990.00, 1, - 089.00, 1, 188.00, 1, 287.00, 1, 386.00, 1, 485.00, 1, 584.00, 1, 683.00, 1, 782.00, 1, 881.00, - 1, 980.00, 2, 079.00, 2, 178.00, 2, 277.00, 2, 376.00, 2, 475.00, 2, 574.00, 2, 673.00, 2, - 772.00, 2, 871.00, 2, 970.00, 3, 069.00, 3, 168.00, 3, 267.00, 3, 366.00, 3, 465.00, 3, 564.00, - 3, 663.00, 3, 762.00, 3, 861.00, 3, 960.00, 4, 059.00, 4, 158.00, 4, 257.00, 4, 356.00, 4, - 455.00, 4, 554.00, 4, 653.00, 4, 752.00, 4, 851.00, 4, 950.00, 5, 049.00, 5, 148.00, 5, 247.00, - 5, 346.00, 5, 445.00, 5, 544.00, 5, 643.00, 5, 742.00, 5, 841.00, 5, 940.00, 6, 039.00, 6, - 138.00, 6, 237.00, 6, 336.00, 6, 435.00, 6, 534.00, 6, 633.00, 6, 732.00, 6, 831.00, 6, 930.00, - 7, 029.00, 7, 128.00, 7, 227.00, 7, 326.00, 7, 425.00, 7, 524.00, 7, 623.00, 7, 722.00, 7, - 821.00, 7, 920.00, 8, 019.00, 8, 118.00, 8, 217.00, 8, 316.00, 8, 415.00, 8, 514.00, 8, 613.00, - 8, 712.00, 8, 811.00, 8, 910.00, 9, 009.00, 9, 108.00, 9, 207.00, 9, 306.00, 9, 405.00, 9, - 504.00, 9, 603.00, 9, 702.00, 9, 801.00, 9, 900.00, 9, 999.00, 10, 098.00, 10, 197.00, 10, - 296.00, 10, 395.00, 10, 494.00, 10, 593.00, 10, 692.00, 10, 791.00, 10, 890.00, 10, 989.00, 11, - 088.00, 11, 187.00, 11, 286.00, 11, 385.00, 11, 484.00, 11, 583.00, 11, 682.00, 11, 781.00, 11, - 880.00, 11, 979.00, 12, 078.00, 12, 177.00, 12, 276.00, 12, 375.00, 12, 474.00, 12, 573.00, 12, - 672.00, 12, 771.00, 12, 870.00, 12, 969.00, 13, 068.00, 13, 167.00, 13, 266.00, 13, 365.00, 13, - 464.00, 13, 563.00, 13, 662.00, 13, 761.00, 13, 860.00, 13, 959.00, 14, 058.00, 14, 157.00, 14, - 256.00, 14, 355.00, 14, 454.00, 14, 553.00, 14, 652.00, 14, 751.00, 14, 850.00, 7.16, 14.31, - 21.47, 28.62, 35.78, 42.94, 50.09, 57.25, 64.40, 71.56, 78.72, 85.87, 93.03, 100.18, 107.34, - 114.50, 121.65, 128.81, 135.96, 143.12, 150.28, 157.43, 164.59, 171.74, 178.90, 186.06, 193.21, - 200.37, 207.52, 214.68, 221.84, 228.99, 236.15, 243.30, 250.46, 257.62, 264.77, 271.93, 279.08, - 286.24, 293.40, 300.55, 307.71, 314.86, 322.02, 329.18, 336.33, 343.49, 350.64, 357.80, 364.96, - 372.11, 379.27, 386.42, 393.58, 400.74, 407.89, 415.05, 422.20, 429.36, 436.52, 443.67, 450.83, - 457.98, 465.14, 472.30, 479.45, 486.61, 493.76, 500.92, 508.08, 515.23, 522.39, 529.54, 536.70, - 543.86, 551.01, 558.17, 565.32, 572.48, 579.64, 586.79, 593.95, 601.10, 608.26, 615.42, 622.57, - 629.73, 636.88, 644.04, 651.20, 658.35, 665.51, 672.66, 679.82, 686.98, 694.13, 701.29, 708.44, - 715.60, 722.76, 729.91, 737.07, 744.22, 751.38, 758.54, 765.69, 772.85, 780.00, 787.16, 794.32, - 801.47, 808.63, 815.78, 822.94, 830.10, 837.25, 844.41, 851.56, 858.72, 865.88, 873.03, 880.19, - 887.34, 894.50, 901.66, 908.81, 915.97, 923.12, 930.28, 937.44, 944.59, 951.75, 958.90, 966.06, - 973.22, 980.37, 987.53, 994.68, 1, 001.84, 1, 009.00, 1, 016.15, 1, 023.31, 1, 030.46, 1, - 037.62, 1, 044.78, 1, 051.93, 1, 059.09, 1, 066.24, 1, 073.40, 9.00, 18.00, 27.00, 36.00, 45.00, - 54.00, 63.00, 72.00, 81.00, 90.00, 99.00, 108.00, 117.00, 126.00, 135.00, 144.00, 153.00, - 162.00, 171.00, 180.00, 189.00, 198.00, 207.00, 216.00, 225.00, 234.00, 243.00, 252.00, 261.00, - 270.00, 279.00, 288.00, 297.00, 306.00, 315.00, 324.00, 333.00, 342.00, 351.00, 360.00, 369.00, - 378.00, 387.00, 396.00, 405.00, 414.00, 423.00, 432.00, 441.00, 450.00, 459.00, 468.00, 477.00, - 486.00, 495.00, 504.00, 513.00, 522.00, 531.00, 540.00, 549.00, 558.00, 567.00, 576.00, 585.00, - 594.00, 603.00, 612.00, 621.00, 630.00, 639.00, 648.00, 657.00, 666.00, 675.00, 684.00, 693.00, - 702.00, 711.00, 720.00, 729.00, 738.00, 747.00, 756.00, 765.00, 774.00, 783.00, 792.00, 801.00, - 810.00, 819.00, 828.00, 837.00, 846.00, 855.00, 864.00, 873.00, 882.00, 891.00, 900.00, 909.00, - 918.00, 927.00, 936.00, 945.00, 954.00, 963.00, 972.00, 981.00, 990.00, 999.00, 1, 008.00, 1, - 017.00, 1, 026.00, 1, 035.00, 1, 044.00, 1, 053.00, 1, 062.00, 1, 071.00, 1, 080.00, 1, 089.00, - 1, 098.00, 1, 107.00, 1, 116.00, 1, 125.00, 1, 134.00, 1, 143.00, 1, 152.00, 1, 161.00, 1, - 170.00, 1, 179.00, 1, 188.00, 1, 197.00, 1, 206.00, 1, 215.00, 1, 224.00, 1, 233.00, 1, 242.00, - 1, 251.00, 1, 260.00, 1, 269.00, 1, 278.00, 1, 287.00, 1, 296.00, 1, 305.00, 1, 314.00, 1, - 323.00, 1, 332.00, 1, 341.00, 1, 350.00}).reshape(1, -1); + 111.00, 122.10, 133.20, 144.30, 155.40, 166.50, 177.60, 188.70, 199.80, 210.90, 222.00, 233.10, + 244.20, 255.30, 266.40, 277.50, 288.60, 299.70, 310.80, 321.90, 333.00, 344.10, 355.20, 366.30, + 377.40, 388.50, 399.60, 410.70, 421.80, 432.90, 444.00, 455.10, 466.20, 477.30, 488.40, 499.50, + 510.60, 521.70, 532.80, 543.90, 555.00, 566.10, 577.20, 588.30, 599.40, 610.50, 621.60, 632.70, + 643.80, 654.90, 666.00, 677.10, 688.20, 699.30, 710.40, 721.50, 732.60, 743.70, 754.80, 765.90, + 777.00, 788.10, 799.20, 810.30, 821.40, 832.50, 843.60, 854.70, 865.80, 876.90, 888.00, 899.10, + 910.20, 921.30, 932.40, 943.50, 954.60, 965.70, 976.80, 987.90, 999.00, 1, 010.10, 1, 021.20, 1, + 032.30, 1, 043.40, 1, 054.50, 1, 065.60, 1, 076.70, 1, 087.80, 1, 098.90, 1, 110.00, 1, 121.10, + 1, 132.20, 1, 143.30, 1, 154.40, 1, 165.50, 1, 176.60, 1, 187.70, 1, 198.80, 1, 209.90, 1, + 221.00, 1, 232.10, 1, 243.20, 1, 254.30, 1, 265.40, 1, 276.50, 1, 287.60, 1, 298.70, 1, 309.80, + 1, 320.90, 1, 332.00, 1, 343.10, 1, 354.20, 1, 365.30, 1, 376.40, 1, 387.50, 1, 398.60, 1, + 409.70, 1, 420.80, 1, 431.90, 1, 443.00, 1, 454.10, 1, 465.20, 1, 476.30, 1, 487.40, 1, 498.50, + 1, 509.60, 1, 520.70, 1, 531.80, 1, 542.90, 1, 554.00, 1, 565.10, 1, 576.20, 1, 587.30, 1, + 598.40, 1, 609.50, 1, 620.60, 1, 631.70, 1, 642.80, 1, 653.90, 1, 665.00, 2.10, 4.20, 6.30, + 8.40, 10.50, 12.60, 14.70, 16.80, 18.90, 21.00, 23.10, 25.20, 27.30, 29.40, 31.50, 33.60, 35.70, + 37.80, 39.90, 42.00, 44.10, 46.20, 48.30, 50.40, 52.50, 54.60, 56.70, 58.80, 60.90, 63.00, + 65.10, 67.20, 69.30, 71.40, 73.50, 75.60, 77.70, 79.80, 81.90, 84.00, 86.10, 88.20, 90.30, + 92.40, 94.50, 96.60, 98.70, 100.80, 102.90, 105.00, 107.10, 109.20, 111.30, 113.40, 115.50, + 117.60, 119.70, 121.80, 123.90, 126.00, 128.10, 130.20, 132.30, 134.40, 136.50, 138.60, 140.70, + 142.80, 144.90, 147.00, 149.10, 151.20, 153.30, 155.40, 157.50, 159.60, 161.70, 163.80, 165.90, + 168.00, 170.10, 172.20, 174.30, 176.40, 178.50, 180.60, 182.70, 184.80, 186.90, 189.00, 191.10, + 193.20, 195.30, 197.40, 199.50, 201.60, 203.70, 205.80, 207.90, 210.00, 212.10, 214.20, 216.30, + 218.40, 220.50, 222.60, 224.70, 226.80, 228.90, 231.00, 233.10, 235.20, 237.30, 239.40, 241.50, + 243.60, 245.70, 247.80, 249.90, 252.00, 254.10, 256.20, 258.30, 260.40, 262.50, 264.60, 266.70, + 268.80, 270.90, 273.00, 275.10, 277.20, 279.30, 281.40, 283.50, 285.60, 287.70, 289.80, 291.90, + 294.00, 296.10, 298.20, 300.30, 302.40, 304.50, 306.60, 308.70, 310.80, 312.90, 315.00, 10.00, + 20.00, 30.00, 40.00, 50.00, 60.00, 70.00, 80.00, 90.00, 100.00, 110.00, 120.00, 130.00, 140.00, + 150.00, 160.00, 170.00, 180.00, 190.00, 200.00, 210.00, 220.00, 230.00, 240.00, 250.00, 260.00, + 270.00, 280.00, 290.00, 300.00, 310.00, 320.00, 330.00, 340.00, 350.00, 360.00, 370.00, 380.00, + 390.00, 400.00, 410.00, 420.00, 430.00, 440.00, 450.00, 460.00, 470.00, 480.00, 490.00, 500.00, + 510.00, 520.00, 530.00, 540.00, 550.00, 560.00, 570.00, 580.00, 590.00, 600.00, 610.00, 620.00, + 630.00, 640.00, 650.00, 660.00, 670.00, 680.00, 690.00, 700.00, 710.00, 720.00, 730.00, 740.00, + 750.00, 760.00, 770.00, 780.00, 790.00, 800.00, 810.00, 820.00, 830.00, 840.00, 850.00, 860.00, + 870.00, 880.00, 890.00, 900.00, 910.00, 920.00, 930.00, 940.00, 950.00, 960.00, 970.00, 980.00, + 990.00, 1, 000.00, 1, 010.00, 1, 020.00, 1, 030.00, 1, 040.00, 1, 050.00, 1, 060.00, 1, 070.00, + 1, 080.00, 1, 090.00, 1, 100.00, 1, 110.00, 1, 120.00, 1, 130.00, 1, 140.00, 1, 150.00, 1, + 160.00, 1, 170.00, 1, 180.00, 1, 190.00, 1, 200.00, 1, 210.00, 1, 220.00, 1, 230.00, 1, 240.00, + 1, 250.00, 1, 260.00, 1, 270.00, 1, 280.00, 1, 290.00, 1, 300.00, 1, 310.00, 1, 320.00, 1, + 330.00, 1, 340.00, 1, 350.00, 1, 360.00, 1, 370.00, 1, 380.00, 1, 390.00, 1, 400.00, 1, 410.00, + 1, 420.00, 1, 430.00, 1, 440.00, 1, 450.00, 1, 460.00, 1, 470.00, 1, 480.00, 1, 490.00, 1, + 500.00, 99.00, 198.00, 297.00, 396.00, 495.00, 594.00, 693.00, 792.00, 891.00, 990.00, 1, + 089.00, 1, 188.00, 1, 287.00, 1, 386.00, 1, 485.00, 1, 584.00, 1, 683.00, 1, 782.00, 1, 881.00, + 1, 980.00, 2, 079.00, 2, 178.00, 2, 277.00, 2, 376.00, 2, 475.00, 2, 574.00, 2, 673.00, 2, + 772.00, 2, 871.00, 2, 970.00, 3, 069.00, 3, 168.00, 3, 267.00, 3, 366.00, 3, 465.00, 3, 564.00, + 3, 663.00, 3, 762.00, 3, 861.00, 3, 960.00, 4, 059.00, 4, 158.00, 4, 257.00, 4, 356.00, 4, + 455.00, 4, 554.00, 4, 653.00, 4, 752.00, 4, 851.00, 4, 950.00, 5, 049.00, 5, 148.00, 5, 247.00, + 5, 346.00, 5, 445.00, 5, 544.00, 5, 643.00, 5, 742.00, 5, 841.00, 5, 940.00, 6, 039.00, 6, + 138.00, 6, 237.00, 6, 336.00, 6, 435.00, 6, 534.00, 6, 633.00, 6, 732.00, 6, 831.00, 6, 930.00, + 7, 029.00, 7, 128.00, 7, 227.00, 7, 326.00, 7, 425.00, 7, 524.00, 7, 623.00, 7, 722.00, 7, + 821.00, 7, 920.00, 8, 019.00, 8, 118.00, 8, 217.00, 8, 316.00, 8, 415.00, 8, 514.00, 8, 613.00, + 8, 712.00, 8, 811.00, 8, 910.00, 9, 009.00, 9, 108.00, 9, 207.00, 9, 306.00, 9, 405.00, 9, + 504.00, 9, 603.00, 9, 702.00, 9, 801.00, 9, 900.00, 9, 999.00, 10, 098.00, 10, 197.00, 10, + 296.00, 10, 395.00, 10, 494.00, 10, 593.00, 10, 692.00, 10, 791.00, 10, 890.00, 10, 989.00, 11, + 088.00, 11, 187.00, 11, 286.00, 11, 385.00, 11, 484.00, 11, 583.00, 11, 682.00, 11, 781.00, 11, + 880.00, 11, 979.00, 12, 078.00, 12, 177.00, 12, 276.00, 12, 375.00, 12, 474.00, 12, 573.00, 12, + 672.00, 12, 771.00, 12, 870.00, 12, 969.00, 13, 068.00, 13, 167.00, 13, 266.00, 13, 365.00, 13, + 464.00, 13, 563.00, 13, 662.00, 13, 761.00, 13, 860.00, 13, 959.00, 14, 058.00, 14, 157.00, 14, + 256.00, 14, 355.00, 14, 454.00, 14, 553.00, 14, 652.00, 14, 751.00, 14, 850.00, 7.16, 14.31, + 21.47, 28.62, 35.78, 42.94, 50.09, 57.25, 64.40, 71.56, 78.72, 85.87, 93.03, 100.18, 107.34, + 114.50, 121.65, 128.81, 135.96, 143.12, 150.28, 157.43, 164.59, 171.74, 178.90, 186.06, 193.21, + 200.37, 207.52, 214.68, 221.84, 228.99, 236.15, 243.30, 250.46, 257.62, 264.77, 271.93, 279.08, + 286.24, 293.40, 300.55, 307.71, 314.86, 322.02, 329.18, 336.33, 343.49, 350.64, 357.80, 364.96, + 372.11, 379.27, 386.42, 393.58, 400.74, 407.89, 415.05, 422.20, 429.36, 436.52, 443.67, 450.83, + 457.98, 465.14, 472.30, 479.45, 486.61, 493.76, 500.92, 508.08, 515.23, 522.39, 529.54, 536.70, + 543.86, 551.01, 558.17, 565.32, 572.48, 579.64, 586.79, 593.95, 601.10, 608.26, 615.42, 622.57, + 629.73, 636.88, 644.04, 651.20, 658.35, 665.51, 672.66, 679.82, 686.98, 694.13, 701.29, 708.44, + 715.60, 722.76, 729.91, 737.07, 744.22, 751.38, 758.54, 765.69, 772.85, 780.00, 787.16, 794.32, + 801.47, 808.63, 815.78, 822.94, 830.10, 837.25, 844.41, 851.56, 858.72, 865.88, 873.03, 880.19, + 887.34, 894.50, 901.66, 908.81, 915.97, 923.12, 930.28, 937.44, 944.59, 951.75, 958.90, 966.06, + 973.22, 980.37, 987.53, 994.68, 1, 001.84, 1, 009.00, 1, 016.15, 1, 023.31, 1, 030.46, 1, + 037.62, 1, 044.78, 1, 051.93, 1, 059.09, 1, 066.24, 1, 073.40, 9.00, 18.00, 27.00, 36.00, 45.00, + 54.00, 63.00, 72.00, 81.00, 90.00, 99.00, 108.00, 117.00, 126.00, 135.00, 144.00, 153.00, + 162.00, 171.00, 180.00, 189.00, 198.00, 207.00, 216.00, 225.00, 234.00, 243.00, 252.00, 261.00, + 270.00, 279.00, 288.00, 297.00, 306.00, 315.00, 324.00, 333.00, 342.00, 351.00, 360.00, 369.00, + 378.00, 387.00, 396.00, 405.00, 414.00, 423.00, 432.00, 441.00, 450.00, 459.00, 468.00, 477.00, + 486.00, 495.00, 504.00, 513.00, 522.00, 531.00, 540.00, 549.00, 558.00, 567.00, 576.00, 585.00, + 594.00, 603.00, 612.00, 621.00, 630.00, 639.00, 648.00, 657.00, 666.00, 675.00, 684.00, 693.00, + 702.00, 711.00, 720.00, 729.00, 738.00, 747.00, 756.00, 765.00, 774.00, 783.00, 792.00, 801.00, + 810.00, 819.00, 828.00, 837.00, 846.00, 855.00, 864.00, 873.00, 882.00, 891.00, 900.00, 909.00, + 918.00, 927.00, 936.00, 945.00, 954.00, 963.00, 972.00, 981.00, 990.00, 999.00, 1, 008.00, 1, + 017.00, 1, 026.00, 1, 035.00, 1, 044.00, 1, 053.00, 1, 062.00, 1, 071.00, 1, 080.00, 1, 089.00, + 1, 098.00, 1, 107.00, 1, 116.00, 1, 125.00, 1, 134.00, 1, 143.00, 1, 152.00, 1, 161.00, 1, + 170.00, 1, 179.00, 1, 188.00, 1, 197.00, 1, 206.00, 1, 215.00, 1, 224.00, 1, 233.00, 1, 242.00, + 1, 251.00, 1, 260.00, 1, 269.00, 1, 278.00, 1, 287.00, 1, 296.00, 1, 305.00, 1, 314.00, 1, + 323.00, 1, 332.00, 1, 341.00, 1, 350.00}).reshape(1, -1); float templateStd = array.std(1).getFloat(0); @@ -240,7 +244,9 @@ public class PreProcessor3D4DTest extends BaseNd4jTest { } @Test - public void testBruteForce4d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBruteForce4d(Nd4jBackend backend) { Construct4dDataSet imageDataSet = new Construct4dDataSet(10, 5, 10, 15); NormalizerStandardize myNormalizer = new NormalizerStandardize(); @@ -265,12 +271,16 @@ public class PreProcessor3D4DTest extends BaseNd4jTest { } @Test - public void test3dRevertStandardize() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test3dRevertStandardize(Nd4jBackend backend) { test3dRevert(new NormalizerStandardize()); } @Test - public void test3dRevertNormalize() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test3dRevertNormalize(Nd4jBackend backend) { test3dRevert(new NormalizerMinMaxScaler()); } @@ -290,7 +300,9 @@ public class PreProcessor3D4DTest extends BaseNd4jTest { } @Test - public void test3dNinMaxScaling() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test3dNinMaxScaling(Nd4jBackend backend) { INDArray values = Nd4j.linspace(-10, 10, 100).reshape(5, 2, 10); DataSet data = new DataSet(values, values); @@ -379,9 +391,9 @@ public class PreProcessor3D4DTest extends BaseNd4jTest { INDArray allImages = Nd4j.rand(new int[] {nExamples, nChannels, height, width}); allImages.get(NDArrayIndex.all(), NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()).muli(100) - .addi(200); + .addi(200); allImages.get(NDArrayIndex.all(), NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.all()).muli(0.01) - .subi(10); + .subi(10); INDArray labels = Nd4j.linspace(1, nChannels, nChannels).reshape('c', nChannels, 1); sampleDataSet = new DataSet(allImages, labels); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessorTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessorTests.java index 7fc3363bb..e6e594dba 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessorTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/PreProcessorTests.java @@ -21,7 +21,9 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; @@ -32,14 +34,13 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import static org.junit.jupiter.api.Assertions.*; -public class PreProcessorTests extends BaseNd4jTest { +public class PreProcessorTests extends BaseNd4jTestWithBackends { - public PreProcessorTests(Nd4jBackend backend) { - super(backend); - } @Test - public void testLabelLastTimeStepPreProcessor(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLabelLastTimeStepPreProcessor(Nd4jBackend backend){ INDArray f = Nd4j.rand(DataType.FLOAT, 3, 5, 8); INDArray l = Nd4j.rand(DataType.FLOAT, 3, 4, 8); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/StandardScalerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/StandardScalerTest.java index 930b33763..57a393862 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/StandardScalerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/StandardScalerTest.java @@ -22,23 +22,23 @@ package org.nd4j.linalg.dataset; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.StandardScaler; import org.nd4j.linalg.factory.Nd4jBackend; -@RunWith(Parameterized.class) -public class StandardScalerTest extends BaseNd4jTest { - public StandardScalerTest(Nd4jBackend backend) { - super(backend); - } + +public class StandardScalerTest extends BaseNd4jTestWithBackends { @Disabled - @Test - public void testScale() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScale(Nd4jBackend backend) { StandardScaler scaler = new StandardScaler(); DataSetIterator iter = new IrisDataSetIterator(10, 150); scaler.fit(iter); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java index efc3c06f0..a6148ad0e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java @@ -21,7 +21,9 @@ package org.nd4j.linalg.dataset.api.preprocessor; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; import org.nd4j.linalg.factory.Nd4j; @@ -29,11 +31,8 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.*; -public class CompositeDataSetPreProcessorTest extends BaseNd4jTest { +public class CompositeDataSetPreProcessorTest extends BaseNd4jTestWithBackends { - public CompositeDataSetPreProcessorTest(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -41,7 +40,9 @@ public class CompositeDataSetPreProcessorTest extends BaseNd4jTest { } @Test() - public void when_preConditionsIsNull_expect_NullPointerException() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_preConditionsIsNull_expect_NullPointerException(Nd4jBackend backend) { assertThrows(NullPointerException.class,() -> { // Assemble CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(); @@ -54,7 +55,9 @@ public class CompositeDataSetPreProcessorTest extends BaseNd4jTest { } @Test - public void when_dataSetIsEmpty_expect_emptyDataSet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_dataSetIsEmpty_expect_emptyDataSet(Nd4jBackend backend) { // Assemble CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(); DataSet ds = new DataSet(null, null); @@ -67,7 +70,9 @@ public class CompositeDataSetPreProcessorTest extends BaseNd4jTest { } @Test - public void when_notStoppingOnEmptyDataSet_expect_allPreProcessorsCalled() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_notStoppingOnEmptyDataSet_expect_allPreProcessorsCalled(Nd4jBackend backend) { // Assemble TestDataSetPreProcessor preProcessor1 = new TestDataSetPreProcessor(true); TestDataSetPreProcessor preProcessor2 = new TestDataSetPreProcessor(true); @@ -83,7 +88,9 @@ public class CompositeDataSetPreProcessorTest extends BaseNd4jTest { } @Test - public void when_stoppingOnEmptyDataSetAndFirstPreProcessorClearDS_expect_firstPreProcessorsCalled() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_stoppingOnEmptyDataSetAndFirstPreProcessorClearDS_expect_firstPreProcessorsCalled(Nd4jBackend backend) { // Assemble TestDataSetPreProcessor preProcessor1 = new TestDataSetPreProcessor(true); TestDataSetPreProcessor preProcessor2 = new TestDataSetPreProcessor(true); @@ -99,7 +106,9 @@ public class CompositeDataSetPreProcessorTest extends BaseNd4jTest { } @Test - public void when_stoppingOnEmptyDataSet_expect_firstPreProcessorsCalled() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_stoppingOnEmptyDataSet_expect_firstPreProcessorsCalled(Nd4jBackend backend) { // Assemble TestDataSetPreProcessor preProcessor1 = new TestDataSetPreProcessor(false); TestDataSetPreProcessor preProcessor2 = new TestDataSetPreProcessor(false); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java index ec353d2d9..6c7e769a8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java @@ -21,7 +21,9 @@ package org.nd4j.linalg.dataset.api.preprocessor; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; @@ -31,11 +33,8 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.*; -public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest { +public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBackends { - public CropAndResizeDataSetPreProcessorTest(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -43,7 +42,9 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest { } @Test() - public void when_originalHeightIsZero_expect_IllegalArgumentException() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_originalHeightIsZero_expect_IllegalArgumentException(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(0, 15, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); @@ -51,7 +52,9 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest { } @Test() - public void when_originalWidthIsZero_expect_IllegalArgumentException() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_originalWidthIsZero_expect_IllegalArgumentException(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 0, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); @@ -59,7 +62,9 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest { } @Test() - public void when_yStartIsNegative_expect_IllegalArgumentException() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_yStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, -1, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); @@ -67,7 +72,9 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest { } @Test() - public void when_xStartIsNegative_expect_IllegalArgumentException() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_xStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, -1, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); @@ -75,7 +82,9 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest { } @Test() - public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 0, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); @@ -83,7 +92,9 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest { } @Test() - public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 0, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); @@ -91,7 +102,9 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest { } @Test() - public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 3, 0, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); @@ -99,7 +112,9 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest { } @Test() - public void when_dataSetIsNull_expect_NullPointerException() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) { // Assemble assertThrows(NullPointerException.class,() -> { CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); @@ -111,7 +126,9 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest { } @Test - public void when_dataSetIsEmpty_expect_emptyDataSet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_dataSetIsEmpty_expect_emptyDataSet(Nd4jBackend backend) { // Assemble CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); DataSet ds = new DataSet(null, null); @@ -124,7 +141,9 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTest { } @Test - public void when_dataSetIs15wx10h_expect_3wx4hDataSet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_dataSetIs15wx10h_expect_3wx4hDataSet(Nd4jBackend backend) { // Assemble int numChannels = 3; int height = 10; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/MinMaxStrategyTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/MinMaxStrategyTest.java index 5a22457c9..c8bfcb593 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/MinMaxStrategyTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/MinMaxStrategyTest.java @@ -21,26 +21,25 @@ package org.nd4j.linalg.dataset.api.preprocessor; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.preprocessor.stats.MinMaxStats; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Ede Meijer */ -@RunWith(Parameterized.class) -public class MinMaxStrategyTest extends BaseNd4jTest { - public MinMaxStrategyTest(Nd4jBackend backend) { - super(backend); - } + +public class MinMaxStrategyTest extends BaseNd4jTestWithBackends { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRowVector() { MinMaxStrategy SUT = new MinMaxStrategy(0, 1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java index 81881fc42..9485f5bcd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java @@ -20,8 +20,9 @@ package org.nd4j.linalg.dataset.api.preprocessor; -import org.nd4j.linalg.BaseNd4jTest; -import org.nd4j.linalg.dataset.api.preprocessor.PermuteDataSetPreProcessor; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -30,11 +31,8 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.*; -public class PermuteDataSetPreProcessorTest extends BaseNd4jTest { +public class PermuteDataSetPreProcessorTest extends BaseNd4jTestWithBackends { - public PermuteDataSetPreProcessorTest(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -42,7 +40,7 @@ public class PermuteDataSetPreProcessorTest extends BaseNd4jTest { } @Test() - public void when_dataSetIsNull_expect_NullPointerException() { + public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) { assertThrows(NullPointerException.class,() -> { // Assemble PermuteDataSetPreProcessor sut = new PermuteDataSetPreProcessor(PermuteDataSetPreProcessor.PermutationTypes.NCHWtoNHWC); @@ -54,7 +52,9 @@ public class PermuteDataSetPreProcessorTest extends BaseNd4jTest { } @Test - public void when_emptyDatasetInInputdataSetIsNCHW_expect_emptyDataSet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_emptyDatasetInInputdataSetIsNCHW_expect_emptyDataSet(Nd4jBackend backend) { // Assemble PermuteDataSetPreProcessor sut = new PermuteDataSetPreProcessor(PermuteDataSetPreProcessor.PermutationTypes.NCHWtoNHWC); DataSet ds = new DataSet(null, null); @@ -67,7 +67,9 @@ public class PermuteDataSetPreProcessorTest extends BaseNd4jTest { } @Test - public void when_dataSetIsNCHW_expect_dataSetTransformedToNHWC() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_dataSetIsNCHW_expect_dataSetTransformedToNHWC(Nd4jBackend backend) { // Assemble int numChannels = 3; int height = 5; @@ -112,7 +114,9 @@ public class PermuteDataSetPreProcessorTest extends BaseNd4jTest { } @Test - public void when_dataSetIsNHWC_expect_dataSetTransformedToNCHW() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_dataSetIsNHWC_expect_dataSetTransformedToNCHW(Nd4jBackend backend) { // Assemble int numChannels = 3; int height = 5; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java index 305c87855..1a2be9f7c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java @@ -21,7 +21,9 @@ package org.nd4j.linalg.dataset.api.preprocessor; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -29,11 +31,8 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.*; -public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTest { +public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTestWithBackends { - public RGBtoGrayscaleDataSetPreProcessorTest(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -41,7 +40,7 @@ public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTest { } @Test() - public void when_dataSetIsNull_expect_NullPointerException() { + public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) { assertThrows(NullPointerException.class,() -> { // Assemble RGBtoGrayscaleDataSetPreProcessor sut = new RGBtoGrayscaleDataSetPreProcessor(); @@ -53,7 +52,9 @@ public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTest { } @Test - public void when_dataSetIsEmpty_expect_EmptyDataSet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_dataSetIsEmpty_expect_EmptyDataSet(Nd4jBackend backend) { // Assemble RGBtoGrayscaleDataSetPreProcessor sut = new RGBtoGrayscaleDataSetPreProcessor(); DataSet ds = new DataSet(null, null); @@ -66,7 +67,9 @@ public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTest { } @Test - public void when_colorsAreConverted_expect_grayScaleResult() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void when_colorsAreConverted_expect_grayScaleResult(Nd4jBackend backend) { // Assign int numChannels = 3; int height = 1; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/UnderSamplingPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/UnderSamplingPreProcessorTest.java index ed5568fea..84e1353db 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/UnderSamplingPreProcessorTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/UnderSamplingPreProcessorTest.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.dataset.api.preprocessor; import lombok.extern.slf4j.Slf4j; import net.jcip.annotations.NotThreadSafe; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.dataset.DataSet; @@ -48,9 +49,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue; * @author susaneraly */ @Slf4j -@RunWith(Parameterized.class) + @NotThreadSafe -public class UnderSamplingPreProcessorTest extends BaseNd4jTest { +public class UnderSamplingPreProcessorTest extends BaseNd4jTestWithBackends { int shortSeq = 10000; int longSeq = 20020; //not a perfect multiple of windowSize int window = 5000; @@ -58,19 +59,18 @@ public class UnderSamplingPreProcessorTest extends BaseNd4jTest { double targetDist = 0.3; double tolerancePerc = 0.03; //10% +/- because this is not a very large sample - public UnderSamplingPreProcessorTest(Nd4jBackend backend) { - super(backend); - } @Test - public void allMajority() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void allMajority(Nd4jBackend backend) { float[] someTargets = new float[] {0.01f, 0.1f, 0.5f}; DataSet d = allMajorityDataSet(false); DataSet dToPreProcess; for (int i = 0; i < someTargets.length; i++) { //if all majority default is to mask all time steps UnderSamplingByMaskingPreProcessor preProcessor = - new UnderSamplingByMaskingPreProcessor(someTargets[i], shortSeq / 2); + new UnderSamplingByMaskingPreProcessor(someTargets[i], shortSeq / 2); dToPreProcess = d.copy(); preProcessor.preProcess(dToPreProcess); INDArray exp = Nd4j.zeros(dToPreProcess.getLabelsMaskArray().shape()); @@ -83,18 +83,20 @@ public class UnderSamplingPreProcessorTest extends BaseNd4jTest { preProcessor.preProcess(dToPreProcess); INDArray percentagesNow = dToPreProcess.getLabelsMaskArray().sum(1).div(shortSeq); assertTrue(Nd4j.valueArrayOf(percentagesNow.shape(), 1 - someTargets[i]).castTo(Nd4j.defaultFloatingPointType()).equalsWithEps(percentagesNow, - tolerancePerc)); + tolerancePerc)); } } @Test - public void allMinority() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void allMinority(Nd4jBackend backend) { float[] someTargets = new float[] {0.01f, 0.1f, 0.5f}; DataSet d = allMinorityDataSet(false); DataSet dToPreProcess; for (int i = 0; i < someTargets.length; i++) { UnderSamplingByMaskingPreProcessor preProcessor = - new UnderSamplingByMaskingPreProcessor(someTargets[i], shortSeq / 2); + new UnderSamplingByMaskingPreProcessor(someTargets[i], shortSeq / 2); dToPreProcess = d.copy(); preProcessor.preProcess(dToPreProcess); //all minority classes present - check that no time steps are masked @@ -116,7 +118,9 @@ public class UnderSamplingPreProcessorTest extends BaseNd4jTest { Checks distribution of classes after preprocessing */ @Test - public void mixedDist() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void mixedDist(Nd4jBackend backend) { UnderSamplingByMaskingPreProcessor preProcessor = new UnderSamplingByMaskingPreProcessor(targetDist, window); @@ -135,7 +139,7 @@ public class UnderSamplingPreProcessorTest extends BaseNd4jTest { //check masks are zero where there are no time steps INDArray masks = dataSetToPreProcess.getLabelsMaskArray(); INDArray shouldBeAllZeros = - masks.get(NDArrayIndex.interval(0, 3), NDArrayIndex.interval(shortSeq, longSeq)); + masks.get(NDArrayIndex.interval(0, 3), NDArrayIndex.interval(shortSeq, longSeq)); assertEquals(Nd4j.zeros(shouldBeAllZeros.shape()), shouldBeAllZeros); //check distribution of masks in window, going backwards from last time step @@ -145,7 +149,7 @@ public class UnderSamplingPreProcessorTest extends BaseNd4jTest { int minIndex = min(0, maxIndex - window); INDArray maskWindow = masks.get(NDArrayIndex.all(), NDArrayIndex.interval(minIndex, maxIndex)); INDArray labelWindow = labels.get(NDArrayIndex.all(), NDArrayIndex.point(0), - NDArrayIndex.interval(minIndex, maxIndex)); + NDArrayIndex.interval(minIndex, maxIndex)); //calc minority class distribution INDArray minorityDist = labelWindow.mul(maskWindow).sum(1).div(maskWindow.sum(1)); @@ -173,7 +177,9 @@ public class UnderSamplingPreProcessorTest extends BaseNd4jTest { Also checks minority override */ @Test - public void mixedDistOneHot() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void mixedDistOneHot(Nd4jBackend backend) { //preprocessor should give 30% minority class for every "window" UnderSamplingByMaskingPreProcessor preProcessor = new UnderSamplingByMaskingPreProcessor(targetDist, window); @@ -194,7 +200,7 @@ public class UnderSamplingPreProcessorTest extends BaseNd4jTest { //check masks are zero where there were no time steps INDArray shouldBeAllZeros = - masks.get(NDArrayIndex.interval(0, 3), NDArrayIndex.interval(shortSeq, longSeq)); + masks.get(NDArrayIndex.interval(0, 3), NDArrayIndex.interval(shortSeq, longSeq)); assertEquals(Nd4j.zeros(shouldBeAllZeros.shape()), shouldBeAllZeros); //check distribution of masks in the window length, going backwards from last time step @@ -204,13 +210,13 @@ public class UnderSamplingPreProcessorTest extends BaseNd4jTest { int minIndex = min(0, maxIndex - window); INDArray maskWindow = masks.get(NDArrayIndex.all(), NDArrayIndex.interval(minIndex, maxIndex)); INDArray labelWindow = labels.get(NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.interval(minIndex, maxIndex)); + NDArrayIndex.interval(minIndex, maxIndex)); //calc minority class distribution after accounting for masks INDArray minorityClass = labelWindow.get(NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.all()) - .mul(maskWindow); + .mul(maskWindow); INDArray majorityClass = labelWindow.get(NDArrayIndex.all(), NDArrayIndex.point(1), NDArrayIndex.all()) - .mul(maskWindow); + .mul(maskWindow); INDArray minorityDist = minorityClass.sum(1).div(majorityClass.add(minorityClass).sum(1)); if (j < shortSeq / window) { @@ -233,7 +239,9 @@ public class UnderSamplingPreProcessorTest extends BaseNd4jTest { //all the tests above into one multidataset @Test - public void testForMultiDataSet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testForMultiDataSet(Nd4jBackend backend) { DataSet dataSetA = knownDistVariedDataSet(new float[] {0.8f, 0.1f, 0.2f}, false); DataSet dataSetB = knownDistVariedDataSet(new float[] {0.2f, 0.9f, 0.8f}, true); @@ -241,7 +249,7 @@ public class UnderSamplingPreProcessorTest extends BaseNd4jTest { targetDists.put(0, 0.5); //balance inputA targetDists.put(1, 0.3); //inputB dist = 0.2% UnderSamplingByMaskingMultiDataSetPreProcessor maskingMultiDataSetPreProcessor = - new UnderSamplingByMaskingMultiDataSetPreProcessor(targetDists, window); + new UnderSamplingByMaskingMultiDataSetPreProcessor(targetDists, window); maskingMultiDataSetPreProcessor.overrideMinorityDefault(1); MultiDataSet multiDataSet = fromDataSet(dataSetA, dataSetB); @@ -263,7 +271,7 @@ public class UnderSamplingPreProcessorTest extends BaseNd4jTest { //datasetB - override is switched so grab index=0 labels = multiDataSet.getLabels(1).get(NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.all()) - .mul(multiDataSet.getLabelsMaskArray(1)); + .mul(multiDataSet.getLabelsMaskArray(1)); minorityCount = labels.sum(1); seqCount = multiDataSet.getLabelsMaskArray(1).sum(1); minorityDist = minorityCount.div(seqCount); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java index 1a4300f96..aeb3db361 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestPCA.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.dimensionalityreduction; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -32,22 +33,19 @@ import org.nd4j.linalg.string.NDArrayStrings; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -@RunWith(Parameterized.class) -public class TestPCA extends BaseNd4jTest { - - public TestPCA(Nd4jBackend backend) { - super(backend); - } +public class TestPCA extends BaseNd4jTestWithBackends { @Test - public void testFactorDims() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFactorDims(Nd4jBackend backend) { int m = 13; int n = 4; double f[] = new double[] {7, 1, 11, 11, 7, 11, 3, 1, 2, 21, 1, 11, 10, 26, 29, 56, 31, 52, 55, 71, 31, 54, 47, - 40, 66, 68, 6, 15, 8, 8, 6, 9, 17, 22, 18, 4, 23, 9, 8, 60, 52, 20, 47, 33, 22, 6, 44, 22, 26, - 34, 12, 12}; + 40, 66, 68, 6, 15, 8, 8, 6, 9, 17, 22, 18, 4, 23, 9, 8, 60, 52, 20, 47, 33, 22, 6, 44, 22, 26, + 34, 12, 12}; INDArray A = Nd4j.create(f, new int[] {m, n}, 'f'); @@ -64,13 +62,15 @@ public class TestPCA extends BaseNd4jTest { } @Test - public void testFactorSVDTransposed() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFactorSVDTransposed(Nd4jBackend backend) { int m = 4; int n = 13; double f[] = new double[] {7, 1, 11, 11, 7, 11, 3, 1, 2, 21, 1, 11, 10, 26, 29, 56, 31, 52, 55, 71, 31, 54, 47, - 40, 66, 68, 6, 15, 8, 8, 6, 9, 17, 22, 18, 4, 23, 9, 8, 60, 52, 20, 47, 33, 22, 6, 44, 22, 26, - 34, 12, 12}; + 40, 66, 68, 6, 15, 8, 8, 6, 9, 17, 22, 18, 4, 23, 9, 8, 60, 52, 20, 47, 33, 22, 6, 44, 22, 26, + 34, 12, 12}; INDArray A = Nd4j.create(f, new long[] {m, n}, 'f'); @@ -87,13 +87,15 @@ public class TestPCA extends BaseNd4jTest { } @Test - public void testFactorVariance() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFactorVariance(Nd4jBackend backend) { int m = 13; int n = 4; double f[] = new double[] {7, 1, 11, 11, 7, 11, 3, 1, 2, 21, 1, 11, 10, 26, 29, 56, 31, 52, 55, 71, 31, 54, 47, - 40, 66, 68, 6, 15, 8, 8, 6, 9, 17, 22, 18, 4, 23, 9, 8, 60, 52, 20, 47, 33, 22, 6, 44, 22, 26, - 34, 12, 12}; + 40, 66, 68, 6, 15, 8, 8, 6, 9, 17, 22, 18, 4, 23, 9, 8, 60, 52, 20, 47, 33, 22, 6, 44, 22, 26, + 34, 12, 12}; INDArray A = Nd4j.create(f, new int[] {m, n}, 'f'); @@ -116,7 +118,9 @@ public class TestPCA extends BaseNd4jTest { * Test new PCA routines, added by Luke Czapla */ @Test - public void testPCA() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPCA(Nd4jBackend backend) { INDArray m = Nd4j.randn(10000, 16); // 10000 random correlated samples of 16 features to analyze m.getColumn(0).muli(4.84); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java index 0df37d632..534e03f50 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dimensionalityreduction/TestRandomProjection.java @@ -23,16 +23,14 @@ package org.nd4j.linalg.dimensionalityreduction; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import org.nd4j.linalg.indexing.BooleanIndexing; -import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.ops.transforms.Transforms; import java.util.ArrayList; @@ -43,18 +41,16 @@ import static org.nd4j.linalg.dimensionalityreduction.RandomProjection.johnsonLi import static org.nd4j.linalg.dimensionalityreduction.RandomProjection.targetShape; @Disabled -@RunWith(Parameterized.class) -public class TestRandomProjection extends BaseNd4jTest { + +public class TestRandomProjection extends BaseNd4jTestWithBackends { INDArray z1 = Nd4j.createUninitialized(new int[]{(int)1e6, 1000}); - public TestRandomProjection(Nd4jBackend backend) { - super(backend); - } - @Test - public void testJohnsonLindenStraussDim() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJohnsonLindenStraussDim(Nd4jBackend backend) { assertEquals(663, (int)johnsonLindenStraussMinDim((int) 1e6, 0.5).get(0)); assertTrue(johnsonLindenStraussMinDim((int) 1e6, 0.5).equals(new ArrayList(Arrays.asList(663)))); @@ -67,7 +63,9 @@ public class TestRandomProjection extends BaseNd4jTest { } @Test - public void testTargetShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTargetShape(Nd4jBackend backend) { assertArrayEquals(targetShape(z1, 0.5), new long[]{1000, 663}); assertArrayEquals(targetShape(Nd4j.createUninitialized(new int[]{(int)1e2, 225}), 0.5), new long[]{225, 221}); // non-changing estimate @@ -75,7 +73,9 @@ public class TestRandomProjection extends BaseNd4jTest { } @Test - public void testTargetEpsilonChecks() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTargetEpsilonChecks(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { // wrong rel. error targetShape(z1, 0.0); @@ -84,7 +84,9 @@ public class TestRandomProjection extends BaseNd4jTest { } @Test - public void testTargetShapeTooHigh() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTargetShapeTooHigh(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { // original dimension too small targetShape(Nd4j.createUninitialized(new int[]{(int)1e2, 1}), 0.5); @@ -99,27 +101,10 @@ public class TestRandomProjection extends BaseNd4jTest { } - private void makeRandomSparseData(int[] shape, double density) { - INDArray z1 = Nd4j.rand(shape); - // because this is rand with mean = 0, stdev = 1, abslessThan ~= density - BooleanIndexing.replaceWhere(z1, 0.0, Conditions.absLessThan(density)); - } - - - private void testRandomProjectionDeterministicForSameShape(){ - INDArray z1 = Nd4j.randn(1000, 500); - RandomProjection rp = new RandomProjection(50); - INDArray res1 = Nd4j.zeros(10000, 442); - rp.projecti(z1, res1); - - INDArray res2 = Nd4j.zeros(10000, 442); - rp.projecti(z1, res2); - - assertEquals(res1, res2); - } - @Test - public void testBasicEmbedding() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicEmbedding(Nd4jBackend backend) { INDArray z1 = Nd4j.randn(10000, 500); RandomProjection rp = new RandomProjection(0.5); INDArray res = Nd4j.zeros(10000, 442); @@ -128,7 +113,9 @@ public class TestRandomProjection extends BaseNd4jTest { } @Test - public void testEmbedding(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmbedding(Nd4jBackend backend) { INDArray z1 = Nd4j.randn(2000, 400); INDArray z2 = z1.dup(); INDArray result = Transforms.allEuclideanDistances(z1, z2, 1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java index 58f226c64..44abcf6b3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java @@ -25,9 +25,10 @@ import org.bytedeco.javacpp.FloatPointer; import org.bytedeco.javacpp.Pointer; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -49,14 +50,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** */ -@RunWith(Parameterized.class) -public class Nd4jTest extends BaseNd4jTest { - public Nd4jTest(Nd4jBackend backend) { - super(backend); - } + +public class Nd4jTest extends BaseNd4jTestWithBackends { @Test - public void testRandShapeAndRNG() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRandShapeAndRNG(Nd4jBackend backend) { INDArray ret = Nd4j.rand(new int[] {4, 2}, Nd4j.getRandomFactory().getNewRandomInstance(123)); INDArray ret2 = Nd4j.rand(new int[] {4, 2}, Nd4j.getRandomFactory().getNewRandomInstance(123)); @@ -64,21 +64,27 @@ public class Nd4jTest extends BaseNd4jTest { } @Test - public void testRandShapeAndMinMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRandShapeAndMinMax(Nd4jBackend backend) { INDArray ret = Nd4j.rand(new int[] {4, 2}, -0.125f, 0.125f, Nd4j.getRandomFactory().getNewRandomInstance(123)); INDArray ret2 = Nd4j.rand(new int[] {4, 2}, -0.125f, 0.125f, Nd4j.getRandomFactory().getNewRandomInstance(123)); assertEquals(ret, ret2); } @Test - public void testCreateShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCreateShape(Nd4jBackend backend) { INDArray ret = Nd4j.create(new int[] {4, 2}); assertEquals(ret.length(), 8); } @Test - public void testCreateFromList() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCreateFromList(Nd4jBackend backend) { List doubles = Arrays.asList(1.0, 2.0); INDArray NdarrayDobules = Nd4j.create(doubles); @@ -92,7 +98,9 @@ public class Nd4jTest extends BaseNd4jTest { } @Test - public void testGetRandom() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRandom(Nd4jBackend backend) { Random r = Nd4j.getRandom(); Random t = Nd4j.getRandom(); @@ -100,7 +108,9 @@ public class Nd4jTest extends BaseNd4jTest { } @Test - public void testGetRandomSetSeed() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRandomSetSeed(Nd4jBackend backend) { Random r = Nd4j.getRandom(); Random t = Nd4j.getRandom(); @@ -110,7 +120,9 @@ public class Nd4jTest extends BaseNd4jTest { } @Test - public void testOrdering() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOrdering(Nd4jBackend backend) { INDArray fNDArray = Nd4j.create(new float[] {1f}, NDArrayFactory.FORTRAN); assertEquals(NDArrayFactory.FORTRAN, fNDArray.ordering()); INDArray cNDArray = Nd4j.create(new float[] {1f}, NDArrayFactory.C); @@ -124,7 +136,9 @@ public class Nd4jTest extends BaseNd4jTest { @Test - public void testMean() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMean(Nd4jBackend backend) { INDArray data = Nd4j.create(new double[] {4., 4., 4., 4., 8., 8., 8., 8., 4., 4., 4., 4., 8., 8., 8., 8., 4., 4., 4., 4., 8., 8., 8., 8., 4., 4., 4., 4., 8., 8., 8., 8, 2., 2., 2., 2., 4., 4., 4., 4., 2., 2., 2., 2., 4., 4., 4., 4., 2., 2., 2., 2., 4., 4., 4., 4., 2., 2., 2., 2., 4., 4., 4., 4.}, @@ -138,7 +152,9 @@ public class Nd4jTest extends BaseNd4jTest { @Test - public void testVar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVar(Nd4jBackend backend) { INDArray data = Nd4j.create(new double[] {4., 4., 4., 4., 8., 8., 8., 8., 4., 4., 4., 4., 8., 8., 8., 8., 4., 4., 4., 4., 8., 8., 8., 8., 4., 4., 4., 4., 8., 8., 8., 8, 2., 2., 2., 2., 4., 4., 4., 4., 2., 2., 2., 2., 4., 4., 4., 4., 2., 2., 2., 2., 4., 4., 4., 4., 2., 2., 2., 2., 4., 4., 4., 4.}, @@ -151,13 +167,17 @@ public class Nd4jTest extends BaseNd4jTest { } @Test - public void testVar2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVar2(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray var = arr.var(false, 0); assertEquals(Nd4j.create(new double[] {2.25, 2.25, 2.25}), var); } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testExpandDims(){ final List> testMatricesC = NDArrayCreationUtil.getAllTestMatricesWithShape('c', 3, 5, 0xDEAD, DataType.DOUBLE); final List> testMatricesF = NDArrayCreationUtil.getAllTestMatricesWithShape('f', 7, 11, 0xBEEF, DataType.DOUBLE); @@ -188,6 +208,8 @@ public class Nd4jTest extends BaseNd4jTest { } } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSqueeze(){ final List> testMatricesC = NDArrayCreationUtil.getAllTestMatricesWithShape('c', 3, 1, 0xDEAD, DataType.DOUBLE); final List> testMatricesF = NDArrayCreationUtil.getAllTestMatricesWithShape('f', 7, 1, 0xBEEF, DataType.DOUBLE); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDBaseTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDBaseTest.java index 745dc00d9..7421cc794 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDBaseTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDBaseTest.java @@ -21,7 +21,9 @@ package org.nd4j.linalg.factory.ops; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -30,10 +32,7 @@ import org.nd4j.linalg.indexing.conditions.Conditions; import static org.junit.jupiter.api.Assertions.*; -public class NDBaseTest extends BaseNd4jTest { - public NDBaseTest(Nd4jBackend backend) { - super(backend); - } +public class NDBaseTest extends BaseNd4jTestWithBackends { @Override public char ordering(){ @@ -43,7 +42,9 @@ public class NDBaseTest extends BaseNd4jTest { // TODO: Comment from the review. We'll remove the new NDBase() at some point. @Test - public void testAll() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAll(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.BOOL, 3, 3); INDArray y = base.all(x, 1); @@ -52,7 +53,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testAny() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAny(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.eye(3).castTo(DataType.BOOL); INDArray y = base.any(x, 1); @@ -61,7 +64,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testArgmax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArgmax(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(new double[][]{{0.75, 0.5, 0.25}, {0.5, 0.75, 0.25}, {0.5, 0.25, 0.75}}); @@ -78,7 +83,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testArgmin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArgmin(Nd4jBackend backend) { //Copy Paste from argmax, replaced with argmin. NDBase base = new NDBase(); @@ -96,7 +103,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testConcat() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcat(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); INDArray y = Nd4j.ones(DataType.DOUBLE, 3, 3); @@ -109,7 +118,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testCumprod() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCumprod(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3, 3); INDArray y = base.cumprod(x, false, false, 0); @@ -123,7 +134,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testCumsum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCumsum(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3, 3); INDArray y = base.cumsum(x, false, false, 0); @@ -136,7 +149,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testDot() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDot(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 3); INDArray y = base.dot(x, x, 0); @@ -145,7 +160,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testDynamicpartition() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDynamicpartition(Nd4jBackend backend) { //Try to execute the sample in the code dcumentation: NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 5); @@ -157,7 +174,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testDynamicStitch() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDynamicStitch(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3, 3); //INDArray y = base.dynamicStitch(new INDArray[]{x, x}, 0); TODO: Fix @@ -165,7 +184,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScalarEq() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarEq(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); INDArray y = base.eq(x, 0.0); @@ -174,7 +195,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testEq() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEq(Nd4jBackend backend) { //element wise eq. NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); @@ -184,7 +207,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testExpandDims() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExpandDims(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(1,2).reshape(1,2); INDArray y = base.expandDims(x, 0); @@ -193,7 +218,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testFill() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFill(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(2, 2); INDArray y = base.fill(x, DataType.DOUBLE, 1.1); @@ -202,7 +229,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testGather() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGather(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); int[] ind = new int[]{0}; @@ -212,7 +241,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScalarGt() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarGt(Nd4jBackend backend) { //Scalar gt. NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); @@ -222,7 +253,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testGt() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGt(Nd4jBackend backend) { //element wise gt. NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); @@ -234,7 +267,9 @@ public class NDBaseTest extends BaseNd4jTest { @Test - public void testScalarGte() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarGte(Nd4jBackend backend) { //Scalar gte. NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); @@ -244,7 +279,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testGte() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGte(Nd4jBackend backend) { //element wise gte. NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); @@ -255,7 +292,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testIdentity() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIdentity(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); INDArray y = base.identity(x); @@ -263,7 +302,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testInvertPermutation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInvertPermutation(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(2,0,1); INDArray y = base.invertPermutation(x); @@ -272,7 +313,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testisNumericTensor() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testisNumericTensor(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); INDArray y = base.isNumericTensor(x); @@ -280,14 +323,18 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testLinspace() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLinspace(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray y = base.linspace(DataType.DOUBLE, 0.0, 9.0, 19); //TODO: test crashes. } @Test - public void testScalarLt() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarLt(Nd4jBackend backend) { //Scalar lt. NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); @@ -297,7 +344,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testLt() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLt(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x1 = Nd4j.zeros(DataType.DOUBLE, 3, 3); INDArray x = Nd4j.ones(DataType.DOUBLE, 3, 3); @@ -307,7 +356,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScalarLte() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarLte(Nd4jBackend backend) { //Scalar gt. NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); @@ -317,7 +368,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testLte() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLte(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x1 = Nd4j.zeros(DataType.DOUBLE, 3, 3); INDArray x = Nd4j.ones(DataType.DOUBLE, 3, 3); @@ -327,7 +380,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testMatchCondition() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatchCondition(Nd4jBackend backend) { // same test as TestMatchTransformOp, NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(1.0, 1.0, 1.0, 0.0, 1.0, 1.0); @@ -337,7 +392,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testMatchConditionCount() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatchConditionCount(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(1.0, 1.0, 1.0, 0.0, 1.0, 1.0); INDArray y = base.matchConditionCount(x, Conditions.epsEquals(0.0)); @@ -361,7 +418,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMax(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.eye(3).castTo(DataType.FLOAT); INDArray y = base.max(x, 0); @@ -374,7 +433,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testMean() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMean(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.eye(3).castTo(DataType.FLOAT); INDArray y = base.mean(x, 0); @@ -387,7 +448,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMin(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.eye(3).castTo(DataType.FLOAT); INDArray y = base.min(x, 0); @@ -400,7 +463,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testMmulTranspose() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmulTranspose(Nd4jBackend backend) { INDArray x = Nd4j.rand(DataType.FLOAT, 4, 3); INDArray y = Nd4j.rand(DataType.FLOAT, 5, 4); INDArray exp = x.transpose().mmul(y.transpose()); @@ -409,7 +474,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testMmul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmul(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3, 3); INDArray x1 = Nd4j.eye(3).castTo(DataType.DOUBLE); @@ -418,7 +485,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScalarNeq() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarNeq(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); INDArray y = base.neq(x, 1.0); @@ -427,7 +496,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testNeq() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNeq(Nd4jBackend backend) { //element wise eq. NDBase base = new NDBase(); INDArray x = Nd4j.zeros(DataType.DOUBLE, 3, 3); @@ -438,7 +509,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testNorm1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm1(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.eye(3).castTo(DataType.FLOAT); INDArray y = base.norm1(x, 0); @@ -451,7 +524,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testNorm2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm2(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.eye(3).castTo(DataType.FLOAT); INDArray y = base.norm2(x, 0); @@ -464,7 +539,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testNormMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNormMax(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.eye(3).castTo(DataType.FLOAT); INDArray y = base.normmax(x, 0); @@ -477,7 +554,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testOneHot() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOneHot(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(0.0, 1.0, 2.0); INDArray y = base.oneHot(x, 1, 0, 1.0, 0.0); @@ -494,7 +573,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testOnesLike() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOnesLike(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(3, 3); INDArray y = base.onesLike(x); @@ -507,7 +588,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testPermute() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPermute(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(1, 6, 6).reshape(2, 3); INDArray y = base.permute(x, 1,0); @@ -515,7 +598,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testProd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testProd(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.eye(3).castTo(DataType.FLOAT); INDArray y = base.prod(x, 0); @@ -528,7 +613,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testRange() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRange(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray y = base.range(0.0, 3.0, 1.0, DataType.DOUBLE); INDArray y_exp = Nd4j.createFromArray(0.0, 1.0, 2.0); @@ -536,7 +623,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testRank() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRank(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.eye(3); INDArray y = base.rank(x); @@ -546,8 +635,10 @@ public class NDBaseTest extends BaseNd4jTest { } /* - @Test - public void testRepeat() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRepeat(Nd4jBackend backend) { fail("AB 2020/01/09 - Not sure what this op is supposed to do..."); NDBase base = new NDBase(); INDArray x = Nd4j.eye(3); @@ -558,7 +649,9 @@ public class NDBaseTest extends BaseNd4jTest { @Test - public void testReplaceWhere() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReplaceWhere(Nd4jBackend backend) { // test from BooleanIndexingTest. NDBase base = new NDBase(); INDArray array1 = Nd4j.createFromArray( 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0); @@ -570,7 +663,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testReshape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReshape(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3, 3); INDArray shape = Nd4j.createFromArray(new long[] {3, 3}); @@ -580,7 +675,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testReverse() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverse(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 6).reshape(2, 3); INDArray y = base.reverse(x, 0); @@ -589,7 +686,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testReverseSequence() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverseSequence(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3,3); INDArray seq_kengths = Nd4j.createFromArray(2,3,1); @@ -604,7 +703,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScalarFloorMod() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarFloorMod(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3, 3); INDArray y = base.scalarFloorMod(x, 2.0); @@ -613,7 +714,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScalarMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarMax(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3, 3); INDArray y = base.scalarMax(x, 5.0); @@ -623,7 +726,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScalarMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarMin(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3, 3); INDArray y = base.scalarMin(x, 5.0); @@ -632,7 +737,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScalarSet() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarSet(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(1.0, 2.0, 0.0, 4.0, 5.0); INDArray y = base.scalarSet(x, 1.0); @@ -641,7 +748,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScatterAdd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterAdd(Nd4jBackend backend) { NDBase base = new NDBase(); //from testScatterOpGradients. @@ -656,7 +765,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScatterDiv() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterDiv(Nd4jBackend backend) { NDBase base = new NDBase(); //from testScatterOpGradients. @@ -671,7 +782,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScatterMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterMax(Nd4jBackend backend) { NDBase base = new NDBase(); //from testScatterOpGradients. @@ -686,7 +799,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScatterMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterMin(Nd4jBackend backend) { NDBase base = new NDBase(); //from testScatterOpGradients. @@ -701,7 +816,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScatterMul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterMul(Nd4jBackend backend) { NDBase base = new NDBase(); //from testScatterOpGradients. @@ -716,7 +833,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testScatterSub() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScatterSub(Nd4jBackend backend) { NDBase base = new NDBase(); //from testScatterOpGradients. @@ -733,7 +852,9 @@ public class NDBaseTest extends BaseNd4jTest { @Test - public void testSegmentMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSegmentMax(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(3, 6, 1, 4, 9,2, 2); INDArray segmentIDs = Nd4j.createFromArray(0,0,1,1,1,2,2); @@ -743,7 +864,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testSegmentMean() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSegmentMean(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(3.0, 6.0, 1.0, 4.0, 9.0,2.0, 2.0); INDArray segmentIDs = Nd4j.createFromArray(0,0,1,1,1,2,2); @@ -753,7 +876,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testSegmentMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSegmentMin(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(3.0, 6.0, 1.0, 4.0, 9.0,2.0, 2.0); INDArray segmentIDs = Nd4j.createFromArray(0,0,1,1,1,2,2); @@ -763,7 +888,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testSegmentProd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSegmentProd(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(3.0, 6.0, 1.0, 4.0, 9.0,2.0, 2.0); INDArray segmentIDs = Nd4j.createFromArray(0,0,1,1,1,2,2); @@ -773,7 +900,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testSegmentSum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSegmentSum(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(3.0, 6.0, 1.0, 4.0, 9.0,2.0, 2.0); INDArray segmentIDs = Nd4j.createFromArray(0,0,1,1,1,2,2); @@ -783,7 +912,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testSequenceMask() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSequenceMask(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray length = Nd4j.createFromArray(1, 3, 2); int maxlength = 5; @@ -798,7 +929,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShape(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.zeros(3,3); INDArray y = base.shape(x); @@ -807,7 +940,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testSize() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSize(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.zeros(3,3); INDArray y = base.size(x); @@ -815,7 +950,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testSizeAt() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSizeAt(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.zeros(10,20, 30); INDArray y = base.sizeAt(x, 1); @@ -823,7 +960,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testSlice() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSlice(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 6).reshape(2, 3); INDArray y = base.slice(x, new int[]{0,1}, 2,1); @@ -832,7 +971,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testSquaredNorm() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSquaredNorm(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3, 3); INDArray y = base.squaredNorm(x, 0); @@ -845,7 +986,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testSqueeze() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSqueeze(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 10).reshape(2,1,5); INDArray y = base.squeeze(x,1); @@ -854,7 +997,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testStack() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStack(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 3); INDArray y = base.stack(1 , x); @@ -862,7 +1007,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testStandardDeviation() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStandardDeviation(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4); INDArray y = base.standardDeviation(x, false, 0); @@ -875,7 +1022,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testStridedSlice() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStridedSlice(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3,3); INDArray y = base.stridedSlice(x, new long[]{0,1}, new long[] {3,3}, 2,1); @@ -885,7 +1034,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testSum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3,3); INDArray y = base.sum(x, 0); @@ -897,7 +1048,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testTensorMul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTensorMul(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3,3); INDArray y = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3,3); @@ -915,7 +1068,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testTile() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTile(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4).reshape(2,2); INDArray repeat = Nd4j.createFromArray(2, 3); @@ -929,7 +1084,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testTranspose() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTranspose(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3,3); INDArray y = base.transpose(x); @@ -938,7 +1095,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testUnsegmentMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUnsegmentMax(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(1,3,2,6,4,9,8); INDArray segmentIDs = Nd4j.createFromArray(1,0,2,0,1,1,2); @@ -948,7 +1107,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testUnsegmentMean() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUnsegmentMean(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(1,3,2,6,4,9,8).castTo(DataType.FLOAT); INDArray segmentIDs = Nd4j.createFromArray(1,0,2,0,1,1,2); @@ -958,7 +1119,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testUnsegmentedMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUnsegmentedMin(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(1,3,2,6,4,9,8); INDArray segmentIDs = Nd4j.createFromArray(1,0,2,0,1,1,2); @@ -968,7 +1131,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testUnsegmentProd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUnsegmentProd(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(1,3,2,6,4,9,8); INDArray segmentIDs = Nd4j.createFromArray(1,0,2,0,1,1,2); @@ -978,7 +1143,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testUnsortedSegmentSqrtN() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUnsortedSegmentSqrtN(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(1.0,3.0,2.0,6.0,4.0,9.0,8.0); INDArray segmentIDs = Nd4j.createFromArray(1,0,2,0,1,1,2); @@ -988,7 +1155,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testUnsortedSegmentSum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUnsortedSegmentSum(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.createFromArray(1,3,2,6,4,9,8); INDArray segmentIDs = Nd4j.createFromArray(1,0,2,0,1,1,2); @@ -998,7 +1167,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testVariance() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariance(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4); INDArray y = base.variance(x, false, 0); @@ -1011,7 +1182,9 @@ public class NDBaseTest extends BaseNd4jTest { } @Test - public void testZerosLike() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testZerosLike(Nd4jBackend backend) { NDBase base = new NDBase(); INDArray x = Nd4j.zeros(3,3); INDArray y = base.zerosLike(x); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java index a4c6f0527..d95c4503f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java @@ -21,10 +21,12 @@ package org.nd4j.linalg.factory.ops; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; @@ -34,10 +36,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -public class NDLossTest extends BaseNd4jTest { - public NDLossTest(Nd4jBackend backend) { - super(backend); - } +public class NDLossTest extends BaseNd4jTestWithBackends { @Override public char ordering(){ @@ -45,7 +44,9 @@ public class NDLossTest extends BaseNd4jTest { } @Test - public void testAbsoluteDifference() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAbsoluteDifference(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); int nOut = 4; @@ -79,7 +80,9 @@ public class NDLossTest extends BaseNd4jTest { } @Test - public void testCosineDistance() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCosineDistance(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); int nOut = 4; @@ -115,7 +118,9 @@ public class NDLossTest extends BaseNd4jTest { } @Test - public void testHingeLoss() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHingeLoss(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); int nOut = 4; @@ -148,7 +153,9 @@ public class NDLossTest extends BaseNd4jTest { } @Test - public void testHuberLoss() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHuberLoss(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); int nOut = 4; @@ -181,7 +188,9 @@ public class NDLossTest extends BaseNd4jTest { } @Test - public void testL2Loss() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testL2Loss(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); int nOut = 4; @@ -199,7 +208,9 @@ public class NDLossTest extends BaseNd4jTest { } @Test - public void testLogLoss() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLogLoss(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); int nOut = 4; @@ -237,7 +248,9 @@ public class NDLossTest extends BaseNd4jTest { } @Test - public void testLogPoisson() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLogPoisson(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); int nOut = 4; @@ -270,7 +283,9 @@ public class NDLossTest extends BaseNd4jTest { } @Test - public void testMeanPairwiseSquaredError() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMeanPairwiseSquaredError(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); int nOut = 4; @@ -304,7 +319,9 @@ public class NDLossTest extends BaseNd4jTest { } @Test - public void testMeanSquaredError() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMeanSquaredError(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); int nOut = 4; @@ -338,7 +355,9 @@ public class NDLossTest extends BaseNd4jTest { } @Test - public void testSigmoidCrossEntropy() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSigmoidCrossEntropy(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); int nOut = 4; @@ -373,7 +392,9 @@ public class NDLossTest extends BaseNd4jTest { } @Test - public void testSoftmaxCrossEntropy() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftmaxCrossEntropy(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); int nOut = 4; @@ -410,7 +431,9 @@ public class NDLossTest extends BaseNd4jTest { } @Test - public void testSparseSoftmaxCrossEntropy() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSparseSoftmaxCrossEntropy(Nd4jBackend backend) { SameDiff sd = SameDiff.create(); int nOut = 4; @@ -437,7 +460,9 @@ public class NDLossTest extends BaseNd4jTest { @Test - public void testWeightedCrossEntropyWithLogits() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testWeightedCrossEntropyWithLogits(Nd4jBackend backend) { // This one from SamediffTests.java SameDiff sameDiff = SameDiff.create(); INDArray targets = Nd4j.create(new long[]{1, 5}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java index ba26e181a..974784882 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java @@ -21,9 +21,11 @@ package org.nd4j.linalg.generated; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,10 +34,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -public class SDLinalgTest extends BaseNd4jTest { - public SDLinalgTest(Nd4jBackend backend) { - super(backend); - } +public class SDLinalgTest extends BaseNd4jTestWithBackends { @Override public char ordering(){ @@ -50,7 +49,9 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test - public void testCholesky() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCholesky(Nd4jBackend backend) { INDArray input = Nd4j.createFromArray( new float[]{ 10.f, 14.f, @@ -73,6 +74,8 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLstsq() { INDArray a = Nd4j.createFromArray(new float[]{ 1.f, 2.f, 3.f, 4.f, @@ -95,6 +98,8 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLu() { SDVariable sdInput = sameDiff.var(Nd4j.createFromArray(new double[]{ 1., 2., 3., 0., 2., 3., 0., 0., 7. @@ -109,6 +114,8 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMatrixBandPart() { INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3); INDArray expected = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3); @@ -119,6 +126,8 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testQr() { INDArray input = Nd4j.createFromArray(new double[]{ 12., -51., 4., @@ -132,7 +141,7 @@ public class SDLinalgTest extends BaseNd4jTest { 0.8464147390303179, -0.3912908119746455, 0.34312406418022884, 0.42320736951515897, 0.9040872694197354, -0.02927016186366648, -0.2821382463434393, 0.17042054976392634, 0.9328559865183932, - -0.07053456158585983, 0.01404065236547358, -0.00109937201747271, + -0.07053456158585983, 0.01404065236547358, -0.00109937201747271, 0.14106912317171966, -0.01665551070074392, -0.10577161246232346 }).reshape(5,3); @@ -151,6 +160,8 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSolve() { INDArray a = Nd4j.createFromArray(new float[] { 2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f @@ -172,6 +183,8 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTriangularSolve() { INDArray a = Nd4j.createFromArray(new float[] { 0.7788f, 0.8012f, 0.7244f, @@ -199,6 +212,8 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCross() { INDArray a = Nd4j.createFromArray(new double[]{1, 2, 3}); INDArray b = Nd4j.createFromArray(new double[]{6, 7, 8}); @@ -212,6 +227,8 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDiag() { INDArray x = Nd4j.createFromArray(new double[]{1,2}); INDArray expected = Nd4j.createFromArray(new double[]{1,0,0,2}).reshape(2,2); @@ -223,6 +240,8 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDiagPart() { INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4).reshape(2,2); INDArray expected = Nd4j.createFromArray(new double[]{1,4}); @@ -234,6 +253,8 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLogdet() { INDArray x = Nd4j.createFromArray(new double[]{ 4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8 @@ -247,6 +268,8 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSvd() { INDArray x = Nd4j.createFromArray(new double[]{ 0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f @@ -259,6 +282,8 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLogdetName() { INDArray x = Nd4j.createFromArray(new double[]{ 4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8 @@ -271,6 +296,8 @@ public class SDLinalgTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testQrNames() { INDArray input = Nd4j.createFromArray(new double[]{ 12., -51., 4., diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java index 5c465317c..a8c66f4e7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.indexing; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; @@ -43,80 +44,97 @@ import java.util.Collections; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class BooleanIndexingTest extends BaseNd4jTest { - public BooleanIndexingTest(Nd4jBackend backend) { - super(backend); - } + +public class BooleanIndexingTest extends BaseNd4jTestWithBackends { /* 1D array checks */ @Test - public void testAnd1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAnd1(Nd4jBackend backend) { INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); assertTrue(BooleanIndexing.and(array, Conditions.greaterThan(0.5f))); } @Test - public void testAnd2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAnd2(Nd4jBackend backend) { INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); assertTrue(BooleanIndexing.and(array, Conditions.lessThan(6.0f))); } @Test - public void testAnd3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAnd3(Nd4jBackend backend) { INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); assertFalse(BooleanIndexing.and(array, Conditions.lessThan(5.0f))); } @Test - public void testAnd4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAnd4(Nd4jBackend backend) { INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); assertFalse(BooleanIndexing.and(array, Conditions.greaterThan(4.0f))); } @Test - public void testAnd5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAnd5(Nd4jBackend backend) { INDArray array = Nd4j.create(new float[] {1e-5f, 1e-5f, 1e-5f, 1e-5f, 1e-5f}); assertTrue(BooleanIndexing.and(array, Conditions.greaterThanOrEqual(1e-5f))); } @Test - public void testAnd6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAnd6(Nd4jBackend backend) { INDArray array = Nd4j.create(new float[] {1e-5f, 1e-5f, 1e-5f, 1e-5f, 1e-5f}); assertFalse(BooleanIndexing.and(array, Conditions.lessThan(1e-5f))); } @Test - public void testAnd7() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAnd7(Nd4jBackend backend) { INDArray array = Nd4j.create(new float[] {1e-5f, 1e-5f, 1e-5f, 1e-5f, 1e-5f}); assertTrue(BooleanIndexing.and(array, Conditions.equals(1e-5f))); } @Test - public void testOr1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOr1(Nd4jBackend backend) { INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); assertTrue(BooleanIndexing.or(array, Conditions.greaterThan(3.0f))); } @Test - public void testOr2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOr2(Nd4jBackend backend) { INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); assertTrue(BooleanIndexing.or(array, Conditions.lessThan(3.0f))); } @Test - public void testOr3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOr3(Nd4jBackend backend) { INDArray array = Nd4j.create(new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); assertFalse(BooleanIndexing.or(array, Conditions.greaterThan(6.0f))); @@ -127,14 +145,18 @@ public class BooleanIndexingTest extends BaseNd4jTest { */ @Test - public void test2dAnd1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test2dAnd1(Nd4jBackend backend) { INDArray array = Nd4j.zeros(10, 10); assertTrue(BooleanIndexing.and(array, Conditions.equals(0f))); } @Test - public void test2dAnd2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test2dAnd2(Nd4jBackend backend) { INDArray array = Nd4j.zeros(10, 10); array.slice(4).putScalar(2, 1e-5f); // System.out.println(array); @@ -145,7 +167,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void test2dAnd3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test2dAnd3(Nd4jBackend backend) { INDArray array = Nd4j.zeros(10, 10); array.slice(4).putScalar(2, 1e-5f); @@ -154,7 +178,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void test2dAnd4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test2dAnd4(Nd4jBackend backend) { INDArray array = Nd4j.zeros(10, 10); array.slice(4).putScalar(2, 1e-5f); @@ -169,7 +195,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { * @throws Exception */ @Test - public void testSliceAssign1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSliceAssign1(Nd4jBackend backend) { INDArray array = Nd4j.zeros(4, 4); INDArray patch = Nd4j.create(new float[] {1e-5f, 1e-5f, 1e-5f}); @@ -190,7 +218,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testConditionalAssign1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConditionalAssign1(Nd4jBackend backend) { INDArray array1 = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7}); INDArray array2 = Nd4j.create(new double[] {7, 6, 5, 4, 3, 2, 1}); INDArray comp = Nd4j.create(new double[] {1, 2, 3, 4, 3, 2, 1}); @@ -201,7 +231,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testCaSTransform1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCaSTransform1(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {1, 2, 0, 4, 5}); INDArray comp = Nd4j.create(new double[] {1, 2, 3, 4, 5}); @@ -211,7 +243,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testCaSTransform2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCaSTransform2(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {1, 2, 0, 4, 5}); INDArray comp = Nd4j.create(new double[] {3, 2, 3, 4, 5}); @@ -221,7 +255,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testCaSPairwiseTransform1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCaSPairwiseTransform1(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {1, 2, 0, 4, 5}); INDArray comp = Nd4j.create(new double[] {1, 2, 3, 4, 5}); @@ -231,7 +267,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testCaRPairwiseTransform1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCaRPairwiseTransform1(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {1, 2, 0, 4, 5}); INDArray comp = Nd4j.create(new double[] {1, 2, 3, 4, 5}); @@ -241,7 +279,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testCaSPairwiseTransform2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCaSPairwiseTransform2(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {1, 2, 0, 4, 5}); INDArray y = Nd4j.create(new double[] {2, 4, 3, 0, 5}); INDArray comp = Nd4j.create(new double[] {2, 4, 3, 4, 5}); @@ -252,7 +292,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testCaRPairwiseTransform2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCaRPairwiseTransform2(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {1, 2, 0, 4, 5}); INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5}); INDArray comp = Nd4j.create(new double[] {2, 4, 0, 4, 5}); @@ -263,7 +305,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testCaSPairwiseTransform3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCaSPairwiseTransform3(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {1, 2, 0, 4, 5}); INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5}); INDArray comp = Nd4j.create(new double[] {2, 4, 3, 4, 5}); @@ -274,7 +318,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testCaRPairwiseTransform3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCaRPairwiseTransform3(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {1, 2, 0, 4, 5}); INDArray y = Nd4j.create(new double[] {2, 4, 3, 4, 5}); INDArray comp = Nd4j.create(new double[] {2, 2, 3, 4, 5}); @@ -286,7 +332,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { @Test - public void testMatchConditionAllDimensions1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatchConditionAllDimensions1(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); int val = (int) Nd4j.getExecutioner().exec(new MatchCondition(array, Conditions.lessThan(5))) @@ -296,7 +344,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testMatchConditionAllDimensions2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatchConditionAllDimensions2(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {0, 1, 2, 3, Double.NaN, 5, 6, 7, 8, 9}); int val = (int) Nd4j.getExecutioner().exec(new MatchCondition(array, Conditions.isNan())) @@ -306,7 +356,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testMatchConditionAllDimensions3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatchConditionAllDimensions3(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {0, 1, 2, 3, Double.NEGATIVE_INFINITY, 5, 6, 7, 8, 9}); int val = (int) Nd4j.getExecutioner() @@ -316,7 +368,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testMatchConditionAlongDimension1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatchConditionAlongDimension1(Nd4jBackend backend) { INDArray array = Nd4j.ones(3, 10); array.getRow(2).assign(0.0); @@ -328,7 +382,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testMatchConditionAlongDimension2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatchConditionAlongDimension2(Nd4jBackend backend) { INDArray array = Nd4j.ones(3, 10); array.getRow(2).assign(0.0).putScalar(0, 1.0); @@ -342,7 +398,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testMatchConditionAlongDimension3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatchConditionAlongDimension3(Nd4jBackend backend) { INDArray array = Nd4j.ones(3, 10); array.getRow(2).assign(0.0).putScalar(0, 1.0); @@ -355,7 +413,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { @Test - public void testConditionalUpdate() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConditionalUpdate(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(-2, 2, 5, DataType.DOUBLE); INDArray ones = Nd4j.ones(DataType.DOUBLE, 5); INDArray exp = Nd4j.create(new double[] {1, 1, 0, 1, 1}); @@ -368,7 +428,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { @Test - public void testFirstIndex1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFirstIndex1(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}); INDArray result = BooleanIndexing.firstIndex(arr, Conditions.greaterThanOrEqual(3)); @@ -376,7 +438,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testFirstIndex2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFirstIndex2(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}); INDArray result = BooleanIndexing.firstIndex(arr, Conditions.lessThan(3)); @@ -384,7 +448,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testLastIndex1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLastIndex1(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}); INDArray result = BooleanIndexing.lastIndex(arr, Conditions.greaterThanOrEqual(3)); @@ -392,7 +458,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testFirstIndex2D() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFirstIndex2D(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[] {1, 2, 3, 0, 1, 3, 7, 8, 9}).reshape('c', 3, 3); INDArray result = BooleanIndexing.firstIndex(arr, Conditions.greaterThanOrEqual(2), 1); INDArray exp = Nd4j.create(new long[] {1, 2, 0}, new long[]{3}, DataType.LONG); @@ -401,7 +469,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testLastIndex2D() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLastIndex2D(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[] {1, 2, 3, 0, 1, 3, 7, 8, 0}).reshape('c', 3, 3); INDArray result = BooleanIndexing.lastIndex(arr, Conditions.greaterThanOrEqual(2), 1); INDArray exp = Nd4j.create(new long[] {2, 2, 1}, new long[]{3}, DataType.LONG); @@ -410,7 +480,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testEpsEquals1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEpsEquals1(Nd4jBackend backend) { INDArray array = Nd4j.create(new double[] {-1, -1, -1e-8, 1e-8, 1, 1}); MatchCondition condition = new MatchCondition(array, Conditions.epsEquals(0.0)); int numZeroes = Nd4j.getExecutioner().exec(condition).getInt(0); @@ -419,7 +491,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testChooseNonZero() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testChooseNonZero(Nd4jBackend backend) { INDArray testArr = Nd4j.create(new double[] { 0.00, 0.51, 0.68, 0.69, 0.86, 0.91, 0.96, 0.97, 0.97, 1.03, 1.13, 1.16, 1.16, 1.17, 1.19, 1.25, 1.25, 1.26, 1.27, 1.28, 1.29, 1.29, 1.29, 1.30, 1.31, 1.32, 1.33, 1.33, 1.35, 1.35, 1.36, 1.37, 1.38, 1.40, 1.41, 1.42, 1.43, 1.44, 1.44, 1.45, 1.45, 1.47, 1.47, 1.51, 1.51, 1.51, 1.52, 1.53, 1.56, 1.57, 1.58, 1.59, 1.61, 1.62, 1.63, 1.63, 1.64, 1.64, 1.66, 1.66, 1.67, 1.67, 1.70, 1.70, 1.70, 1.72, 1.72, 1.72, 1.72, 1.73, 1.74, 1.74, 1.76, 1.76, 1.77, 1.77, 1.80, 1.80, 1.81, 1.82, 1.83, 1.83, 1.84, 1.84, 1.84, 1.85, 1.85, 1.85, 1.86, 1.86, 1.87, 1.88, 1.89, 1.89, 1.89, 1.89, 1.89, 1.91, 1.91, 1.91, 1.92, 1.94, 1.95, 1.97, 1.98, 1.98, 1.98, 1.98, 1.98, 1.99, 2.00, 2.00, 2.01, 2.01, 2.02, 2.03, 2.03, 2.03, 2.04, 2.04, 2.05, 2.06, 2.07, 2.08, 2.08, 2.08, 2.08, 2.09, 2.09, 2.10, 2.10, 2.11, 2.11, 2.11, 2.12, 2.12, 2.13, 2.13, 2.14, 2.14, 2.14, 2.14, 2.15, 2.15, 2.16, 2.16, 2.16, 2.16, 2.16, 2.17 }); @@ -431,7 +505,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testChooseBasic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testChooseBasic(Nd4jBackend backend) { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(true); INDArray arr = Nd4j.linspace(1,4,4, Nd4j.dataType()).reshape(2,2); @@ -441,14 +517,18 @@ public class BooleanIndexingTest extends BaseNd4jTest { @Test - public void testChooseGreaterThanZero() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testChooseGreaterThanZero(Nd4jBackend backend) { INDArray zero = Nd4j.linspace(0,4,4, Nd4j.dataType()); INDArray filtered = BooleanIndexing.chooseFrom(new INDArray[]{zero},Arrays.asList(0.0), Collections.emptyList(),new GreaterThan()); assertEquals(3, filtered.length()); } @Test - public void testChooseNone() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testChooseNone(Nd4jBackend backend) { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(true); INDArray arr = Nd4j.linspace(1,4,4, Nd4j.dataType()).reshape(2,2); @@ -458,7 +538,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { @Test - public void testWhere() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testWhere(Nd4jBackend backend) { INDArray data = Nd4j.create(4); INDArray mask = Nd4j.create(DataType.BOOL, 4); INDArray put = Nd4j.create(4); @@ -484,7 +566,9 @@ public class BooleanIndexingTest extends BaseNd4jTest { } @Test - public void testEpsStuff_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEpsStuff_1(Nd4jBackend backend) { val dtype = Nd4j.dataType(); val array = Nd4j.create(new float[]{0.001f, 5e-6f, 5e-6f, 5e-6f, 5e-6f}); val exp = Nd4j.create(new float[]{0.001f, 1.0f, 1.0f, 1.0f, 1.0f}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/TransformsTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/TransformsTest.java index f4b59fdb1..68bc330ed 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/TransformsTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/TransformsTest.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.indexing; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -36,16 +37,15 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j -@RunWith(Parameterized.class) -public class TransformsTest extends BaseNd4jTest { - public TransformsTest(Nd4jBackend backend) { - super(backend); - } +public class TransformsTest extends BaseNd4jTestWithBackends { + @Test - public void testEq1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEq1(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {0, 1, 2, 1}); INDArray exp = Nd4j.create(new boolean[] {false, false, true, false}); @@ -55,7 +55,9 @@ public class TransformsTest extends BaseNd4jTest { } @Test - public void testNEq1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNEq1(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {0, 1, 2, 1}); INDArray exp = Nd4j.create(new boolean[] {true, false, true, false}); @@ -65,7 +67,9 @@ public class TransformsTest extends BaseNd4jTest { } @Test - public void testLT1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLT1(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {0, 1, 2, 1}); INDArray exp = Nd4j.create(new boolean[] {true, true, false, true}); @@ -76,7 +80,9 @@ public class TransformsTest extends BaseNd4jTest { @Test - public void testGT1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGT1(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {0, 1, 2, 4}); INDArray exp = Nd4j.create(new boolean[] {false, false, true, true}); @@ -87,7 +93,9 @@ public class TransformsTest extends BaseNd4jTest { @Test - public void testScalarMinMax1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarMinMax1(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {1, 3, 5, 7}); INDArray xCopy = x.dup(); INDArray exp1 = Nd4j.create(new double[] {1, 3, 5, 7}); @@ -110,7 +118,9 @@ public class TransformsTest extends BaseNd4jTest { } @Test - public void testArrayMinMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArrayMinMax(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {1, 3, 5, 7}); INDArray y = Nd4j.create(new double[] {2, 2, 6, 6}); INDArray xCopy = x.dup(); @@ -143,7 +153,9 @@ public class TransformsTest extends BaseNd4jTest { } @Test - public void testAnd1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAnd1(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {0, 0, 1, 0, 0}); INDArray y = Nd4j.create(new double[] {0, 0, 1, 1, 0}); INDArray e = Nd4j.create(new boolean[] {false, false, true, false, false}); @@ -154,7 +166,9 @@ public class TransformsTest extends BaseNd4jTest { } @Test - public void testOr1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOr1(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {0, 0, 1, 0, 0}); INDArray y = Nd4j.create(new double[] {0, 0, 1, 1, 0}); val e = Nd4j.create(new boolean[] {false, false, true, true, false}); @@ -165,7 +179,9 @@ public class TransformsTest extends BaseNd4jTest { } @Test - public void testXor1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testXor1(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {0, 0, 1, 0, 0}); INDArray y = Nd4j.create(new double[] {0, 0, 1, 1, 0}); INDArray exp = Nd4j.create(new boolean[] {false, false, false, true, false}); @@ -176,7 +192,9 @@ public class TransformsTest extends BaseNd4jTest { } @Test - public void testNot1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNot1(Nd4jBackend backend) { INDArray x = Nd4j.create(new double[] {0, 0, 1, 0, 0}); INDArray exp = Nd4j.create(new boolean[] {false, false, true, false, false}); @@ -186,7 +204,9 @@ public class TransformsTest extends BaseNd4jTest { } @Test - public void testSlice_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSlice_1(Nd4jBackend backend) { val arr = Nd4j.linspace(1,4, 4, DataType.FLOAT).reshape(2, 2, 1); val exp0 = Nd4j.create(new float[]{1, 2}, new int[] {2, 1}); val exp1 = Nd4j.create(new float[]{3, 4}, new int[] {2, 1}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/inverse/TestInvertMatrices.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/inverse/TestInvertMatrices.java index 22e0de225..c326b7890 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/inverse/TestInvertMatrices.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/inverse/TestInvertMatrices.java @@ -25,9 +25,10 @@ import org.apache.commons.math3.linear.LUDecomposition; import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.RealMatrix; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.checkutil.CheckUtil; @@ -40,16 +41,15 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class TestInvertMatrices extends BaseNd4jTest { + +public class TestInvertMatrices extends BaseNd4jTestWithBackends { - public TestInvertMatrices(Nd4jBackend backend) { - super(backend); - } @Test - public void testInverse() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInverse(Nd4jBackend backend) { RealMatrix matrix = new Array2DRowRealMatrix(new double[][] {{1, 2}, {3, 4}}); RealMatrix inverse = MatrixUtils.inverse(matrix); @@ -62,7 +62,9 @@ public class TestInvertMatrices extends BaseNd4jTest { } @Test - public void testInverseComparison() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInverseComparison(Nd4jBackend backend) { List> list = NDArrayCreationUtil.getAllTestMatricesWithShape(10, 10, 12345, DataType.DOUBLE); @@ -79,7 +81,9 @@ public class TestInvertMatrices extends BaseNd4jTest { } @Test - public void testInvalidMatrixInversion() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInvalidMatrixInversion(Nd4jBackend backend) { try { InvertMatrix.invert(Nd4j.create(5, 4), false); fail("No exception thrown for invalid input"); @@ -100,6 +104,8 @@ public class TestInvertMatrices extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testInvertMatrixScalar(){ INDArray in = Nd4j.valueArrayOf(new int[]{1,1}, 2); INDArray out1 = InvertMatrix.invert(in, false); @@ -115,7 +121,9 @@ public class TestInvertMatrices extends BaseNd4jTest { * Example from: here */ @Test - public void testLeftPseudoInvert() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLeftPseudoInvert(Nd4jBackend backend) { INDArray X = Nd4j.create(new double[][]{{1, 2}, {3, 4}, {5, 6}}); INDArray expectedLeftInverse = Nd4j.create(new double[][]{{-16, -4, 8}, {13, 4, -5}}).mul(1 / 12d); INDArray leftInverse = InvertMatrix.pLeftInvert(X, false); @@ -162,7 +170,9 @@ public class TestInvertMatrices extends BaseNd4jTest { * Example from: here */ @Test - public void testRightPseudoInvert() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRightPseudoInvert(Nd4jBackend backend) { INDArray X = Nd4j.create(new double[][]{{1, 2}, {3, 4}, {5, 6}}).transpose(); INDArray expectedRightInverse = Nd4j.create(new double[][]{{-16, 13}, {-4, 4}, {8, -5}}).mul(1 / 12d); INDArray rightInverse = InvertMatrix.pRightInvert(X, false); @@ -190,8 +200,10 @@ public class TestInvertMatrices extends BaseNd4jTest { /** * Try to compute the right pseudo inverse of a matrix without full row rank (x1 = 2*x2) */ - @Test() - public void testRightPseudoInvertWithNonFullRowRank() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRightPseudoInvertWithNonFullRowRank(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { INDArray X = Nd4j.create(new double[][]{{1, 2}, {3, 6}, {5, 10}}).transpose(); INDArray rightInverse = InvertMatrix.pRightInvert(X, false); @@ -202,8 +214,10 @@ public class TestInvertMatrices extends BaseNd4jTest { /** * Try to compute the left pseudo inverse of a matrix without full column rank (x1 = 2*x2) */ - @Test() - public void testLeftPseudoInvertWithNonFullColumnRank() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLeftPseudoInvertWithNonFullColumnRank(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { INDArray X = Nd4j.create(new double[][]{{1, 2}, {3, 6}, {5, 10}}); INDArray leftInverse = InvertMatrix.pLeftInvert(X, false); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsC.java index 7fc8e85d7..5f1d0427c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsC.java @@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -35,14 +36,9 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j -@RunWith(Parameterized.class) -public class LapackTestsC extends BaseNd4jTest { - DataType initialType; - public LapackTestsC(Nd4jBackend backend) { - super(backend); - initialType = Nd4j.dataType(); - } +public class LapackTestsC extends BaseNd4jTestWithBackends { + DataType initialType = Nd4j.dataType(); @BeforeEach public void setUp() { @@ -55,10 +51,12 @@ public class LapackTestsC extends BaseNd4jTest { } @Test - public void testGetRF1DifferentOrders() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRF1DifferentOrders(Nd4jBackend backend) { INDArray a = Nd4j.linspace(1, 9, 9, Nd4j.dataType()).reshape(3, 3); INDArray exp = Nd4j.create(new double[] {7.0, 8.0, 9.0, 0.14285715, 0.85714287, 1.7142857, 0.5714286, 0.5, 0.0}, - new int[] {3, 3}, 'c'); + new int[] {3, 3}, 'c'); INDArray r = Nd4j.getNDArrayFactory().lapack().getrf(a); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsF.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsF.java index 1c27010ae..c721dbf24 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsF.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lapack/LapackTestsF.java @@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -35,14 +36,9 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j -@RunWith(Parameterized.class) -public class LapackTestsF extends BaseNd4jTest { - DataType initialType; - public LapackTestsF(Nd4jBackend backend) { - super(backend); - initialType = Nd4j.dataType(); - } +public class LapackTestsF extends BaseNd4jTestWithBackends { + DataType initialType = Nd4j.dataType(); @BeforeEach public void setUp() { @@ -54,8 +50,10 @@ public class LapackTestsF extends BaseNd4jTest { Nd4j.setDataType(initialType); } - @Test - public void testGetRF1DifferentOrders() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRF1DifferentOrders(Nd4jBackend backend) { INDArray a = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9}, new int[] {3, 3}, 'c').dup('f'); INDArray exp = Nd4j.create(new double[] {7.0, 8.0, 9.0, 0.14285715, 0.85714287, 1.7142857, 0.5714286, 0.5, 0.0}, new int[] {3, 3}, 'c').dup('f'); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterTest.java index 897098ace..5f71ab417 100755 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterTest.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.learning; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.distribution.Distribution; import org.nd4j.linalg.factory.Nd4j; @@ -37,16 +38,15 @@ import org.nd4j.linalg.learning.config.Nesterovs; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) -public class UpdaterTest extends BaseNd4jTest { - public UpdaterTest(Nd4jBackend backend) { - super(backend); - } +public class UpdaterTest extends BaseNd4jTestWithBackends { + @Test - public void testAdaGradLegacy() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdaGradLegacy(Nd4jBackend backend) { int rows = 1; int cols = 1; @@ -59,7 +59,9 @@ public class UpdaterTest extends BaseNd4jTest { } @Test - public void testNesterovs() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNesterovs(Nd4jBackend backend) { int rows = 10; int cols = 2; @@ -78,7 +80,9 @@ public class UpdaterTest extends BaseNd4jTest { } @Test - public void testAdaGrad() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdaGrad(Nd4jBackend backend) { int rows = 10; int cols = 2; @@ -98,7 +102,9 @@ public class UpdaterTest extends BaseNd4jTest { } @Test - public void testAdaDelta() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdaDelta(Nd4jBackend backend) { int rows = 10; int cols = 2; @@ -118,7 +124,9 @@ public class UpdaterTest extends BaseNd4jTest { } @Test - public void testAdam() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdam(Nd4jBackend backend) { int rows = 10; int cols = 2; @@ -138,7 +146,9 @@ public class UpdaterTest extends BaseNd4jTest { } @Test - public void testNadam() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNadam(Nd4jBackend backend) { int rows = 10; int cols = 2; @@ -157,7 +167,9 @@ public class UpdaterTest extends BaseNd4jTest { } @Test - public void testAdaMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdaMax(Nd4jBackend backend) { int rows = 10; int cols = 2; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java index 27409e0d6..e4d6a8099 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java @@ -21,7 +21,9 @@ package org.nd4j.linalg.learning; import lombok.val; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater; @@ -42,11 +44,8 @@ import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; -public class UpdaterValidation extends BaseNd4jTest { +public class UpdaterValidation extends BaseNd4jTestWithBackends { - public UpdaterValidation(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -54,7 +53,9 @@ public class UpdaterValidation extends BaseNd4jTest { } @Test - public void testAdaDeltaUpdater(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdaDeltaUpdater(Nd4jBackend backend) { double rho = 0.95; double epsilon = 1e-6; @@ -93,7 +94,9 @@ public class UpdaterValidation extends BaseNd4jTest { } @Test - public void testAdaGradUpdater(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdaGradUpdater(Nd4jBackend backend) { double lr = 0.1; double epsilon = 1e-6; @@ -127,7 +130,9 @@ public class UpdaterValidation extends BaseNd4jTest { @Test - public void testAdamUpdater(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdamUpdater(Nd4jBackend backend) { double lr = 1e-3; double beta1 = 0.9; @@ -169,7 +174,9 @@ public class UpdaterValidation extends BaseNd4jTest { } @Test - public void testAdaMaxUpdater(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdaMaxUpdater(Nd4jBackend backend) { double lr = 1e-3; double beta1 = 0.9; double beta2 = 0.999; @@ -210,7 +217,9 @@ public class UpdaterValidation extends BaseNd4jTest { } @Test - public void testAmsGradUpdater(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAmsGradUpdater(Nd4jBackend backend) { double lr = 1e-3; double beta1 = 0.9; double beta2 = 0.999; @@ -257,7 +266,9 @@ public class UpdaterValidation extends BaseNd4jTest { } @Test - public void testNadamUpdater(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNadamUpdater(Nd4jBackend backend) { double lr = 1e-3; double beta1 = 0.9; @@ -299,7 +310,9 @@ public class UpdaterValidation extends BaseNd4jTest { } @Test - public void testNesterovUpdater(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNesterovUpdater(Nd4jBackend backend) { double lr = 0.1; double momentum = 0.9; @@ -331,7 +344,9 @@ public class UpdaterValidation extends BaseNd4jTest { } @Test - public void testRmsPropUpdater(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRmsPropUpdater(Nd4jBackend backend) { double lr = 0.1; double decay = 0.95; @@ -365,7 +380,9 @@ public class UpdaterValidation extends BaseNd4jTest { } @Test - public void testSgdUpdater(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSgdUpdater(Nd4jBackend backend) { double lr = 0.1; SgdUpdater u = (SgdUpdater) new Sgd(lr).instantiate((Map)null, true); @@ -386,8 +403,10 @@ public class UpdaterValidation extends BaseNd4jTest { /* - @Test - public void createUpdaterTestCases(){ + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void createUpdaterTestCases(Nd4jBackend backend) { Nd4j.create(1); Nd4j.getRandom().setSeed(12345); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionJson.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionJson.java index c1a75fbbd..3668142f6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionJson.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionJson.java @@ -21,7 +21,9 @@ package org.nd4j.linalg.lossfunctions; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -47,14 +49,13 @@ import org.nd4j.shade.jackson.databind.SerializationFeature; import static org.junit.jupiter.api.Assertions.assertEquals; -public class LossFunctionJson extends BaseNd4jTest { +public class LossFunctionJson extends BaseNd4jTestWithBackends { - public LossFunctionJson(Nd4jBackend backend) { - super(backend); - } - @Test - public void testJsonSerialization() throws Exception { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJsonSerialization(Nd4jBackend backend) throws Exception { INDArray w = Nd4j.create(new double[] {1.0, 2.0, 3.0}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java index a39f715f0..ef585e825 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java @@ -21,7 +21,9 @@ package org.nd4j.linalg.lossfunctions; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationSoftmax; @@ -47,14 +49,13 @@ import static junit.framework.TestCase.assertFalse; import static junit.framework.TestCase.assertTrue; import static org.junit.jupiter.api.Assertions.assertEquals; -public class LossFunctionTest extends BaseNd4jTest { +public class LossFunctionTest extends BaseNd4jTestWithBackends { - public LossFunctionTest(Nd4jBackend backend) { - super(backend); - } @Test - public void testClippingXENT() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testClippingXENT(Nd4jBackend backend) { ILossFunction l1 = new LossBinaryXENT(0); ILossFunction l2 = new LossBinaryXENT(); @@ -83,7 +84,9 @@ public class LossFunctionTest extends BaseNd4jTest { } @Test - public void testWeightedLossFunctionDTypes(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testWeightedLossFunctionDTypes(Nd4jBackend backend){ for(DataType activationsDt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){ for(DataType weightsDt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java index a4ef0632d..445938020 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java @@ -22,18 +22,17 @@ package org.nd4j.linalg.lossfunctions; import org.junit.Assert; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; -public class TestLossFunctionsSizeChecks extends BaseNd4jTest { +public class TestLossFunctionsSizeChecks extends BaseNd4jTestWithBackends { - public TestLossFunctionsSizeChecks(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -41,13 +40,15 @@ public class TestLossFunctionsSizeChecks extends BaseNd4jTest { } @Test - public void testL2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testL2(Nd4jBackend backend) { LossFunction[] lossFunctionList = {LossFunction.MSE, LossFunction.L1, LossFunction.XENT, - LossFunction.MCXENT, LossFunction.SQUARED_LOSS, LossFunction.RECONSTRUCTION_CROSSENTROPY, - LossFunction.NEGATIVELOGLIKELIHOOD, LossFunction.COSINE_PROXIMITY, LossFunction.HINGE, - LossFunction.SQUARED_HINGE, LossFunction.KL_DIVERGENCE, LossFunction.MEAN_ABSOLUTE_ERROR, - LossFunction.L2, LossFunction.MEAN_ABSOLUTE_PERCENTAGE_ERROR, - LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR, LossFunction.POISSON}; + LossFunction.MCXENT, LossFunction.SQUARED_LOSS, LossFunction.RECONSTRUCTION_CROSSENTROPY, + LossFunction.NEGATIVELOGLIKELIHOOD, LossFunction.COSINE_PROXIMITY, LossFunction.HINGE, + LossFunction.SQUARED_HINGE, LossFunction.KL_DIVERGENCE, LossFunction.MEAN_ABSOLUTE_ERROR, + LossFunction.L2, LossFunction.MEAN_ABSOLUTE_PERCENTAGE_ERROR, + LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR, LossFunction.POISSON}; testLossFunctions(lossFunctionList); } @@ -69,34 +70,34 @@ public class TestLossFunctionsSizeChecks extends BaseNd4jTest { INDArray labels = Nd4j.create(100, 32); INDArray preOutput = Nd4j.create(100, 44); double score = loss.computeScore(labels, preOutput, Activation.IDENTITY.getActivationFunction(), null, - true); + true); Assert.assertFalse( - "Loss function " + loss.toString() - + "did not check for size mismatch. This should fail to compute an activation function because the sizes of the vectors are not equal", - true); + "Loss function " + loss.toString() + + "did not check for size mismatch. This should fail to compute an activation function because the sizes of the vectors are not equal", + true); } catch (IllegalArgumentException ex) { String exceptionMessage = ex.getMessage(); Assert.assertTrue( - "Loss function exception " + loss.toString() - + " did not indicate size mismatch when vectors of incorrect size were used.", - exceptionMessage.contains("shapes")); + "Loss function exception " + loss.toString() + + " did not indicate size mismatch when vectors of incorrect size were used.", + exceptionMessage.contains("shapes")); } try { INDArray labels = Nd4j.create(100, 32); INDArray preOutput = Nd4j.create(100, 44); INDArray gradient = - loss.computeGradient(labels, preOutput, Activation.IDENTITY.getActivationFunction(), null); + loss.computeGradient(labels, preOutput, Activation.IDENTITY.getActivationFunction(), null); Assert.assertFalse( - "Loss function " + loss.toString() - + "did not check for size mismatch. This should fail to compute an activation function because the sizes of the vectors are not equal", - true); + "Loss function " + loss.toString() + + "did not check for size mismatch. This should fail to compute an activation function because the sizes of the vectors are not equal", + true); } catch (IllegalArgumentException ex) { String exceptionMessage = ex.getMessage(); Assert.assertTrue( - "Loss function exception " + loss.toString() - + " did not indicate size mismatch when vectors of incorrect size were used.", - exceptionMessage.contains("shapes")); + "Loss function exception " + loss.toString() + + " did not indicate size mismatch when vectors of incorrect size were used.", + exceptionMessage.contains("shapes")); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java index 6f4213554..d22a83ad6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/AccountingTests.java @@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.AllocationsTracker; import org.nd4j.linalg.api.memory.DeviceAllocationsTracker; @@ -41,14 +42,12 @@ import static org.junit.jupiter.api.Assertions.*; @Slf4j @Disabled -@RunWith(Parameterized.class) -public class AccountingTests extends BaseNd4jTest { - public AccountingTests(Nd4jBackend backend) { - super(backend); - } +public class AccountingTests extends BaseNd4jTestWithBackends { @Test - public void testDetached_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDetached_1(Nd4jBackend backend) { val array = Nd4j.createFromArray(1, 2, 3, 4, 5); assertEquals(DataType.INT, array.dataType()); @@ -56,7 +55,9 @@ public class AccountingTests extends BaseNd4jTest { } @Test - public void testDetached_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDetached_2(Nd4jBackend backend) { val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); val before = Nd4j.getMemoryManager().allocatedMemory(deviceId); @@ -71,7 +72,9 @@ public class AccountingTests extends BaseNd4jTest { } @Test - public void testWorkspaceAccounting_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testWorkspaceAccounting_1(Nd4jBackend backend) { val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); val wsConf = WorkspaceConfiguration.builder() .initialSize(10 * 1024 * 1024) @@ -95,7 +98,9 @@ public class AccountingTests extends BaseNd4jTest { } @Test - public void testWorkspaceAccounting_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testWorkspaceAccounting_2(Nd4jBackend backend) { val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); val wsConf = WorkspaceConfiguration.builder() .initialSize(0) @@ -124,7 +129,9 @@ public class AccountingTests extends BaseNd4jTest { } @Test - public void testManualDeallocation_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testManualDeallocation_1(Nd4jBackend backend) { val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); val before = Nd4j.getMemoryManager().allocatedMemory(deviceId); @@ -143,7 +150,9 @@ public class AccountingTests extends BaseNd4jTest { } @Test - public void testTracker_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTracker_1(Nd4jBackend backend) { val tracker = new DeviceAllocationsTracker(); for (val e: AllocationKind.values()) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java index 5ded208b6..a7ceeca5d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.memory; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -34,14 +35,13 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import static org.junit.jupiter.api.Assertions.*; @Slf4j -@RunWith(Parameterized.class) -public class CloseableTests extends BaseNd4jTest { - public CloseableTests(Nd4jBackend backend) { - super(backend); - } + +public class CloseableTests extends BaseNd4jTestWithBackends { @Test - public void testSimpleRelease_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSimpleRelease_1(Nd4jBackend backend) { val array = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5}); assertTrue(array.closeable()); @@ -51,7 +51,9 @@ public class CloseableTests extends BaseNd4jTest { } @Test - public void testCyclicRelease_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCyclicRelease_1(Nd4jBackend backend) { for (int e = 0; e < 100; e++) { try (val array = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5})) { array.addi(1.0f); @@ -61,7 +63,9 @@ public class CloseableTests extends BaseNd4jTest { } @Test - public void testViewRelease_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testViewRelease_1(Nd4jBackend backend) { val array = Nd4j.create(5, 5); assertTrue(array.closeable()); @@ -72,7 +76,9 @@ public class CloseableTests extends BaseNd4jTest { } @Test - public void testAttachedRelease_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAttachedRelease_1(Nd4jBackend backend) { val wsconf = WorkspaceConfiguration.builder().build(); try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsconf, "haha72yjhfdfs")) { @@ -82,7 +88,9 @@ public class CloseableTests extends BaseNd4jTest { } @Test() - public void testAccessException_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAccessException_1(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { val array = Nd4j.create(5, 5); array.close(); @@ -93,7 +101,9 @@ public class CloseableTests extends BaseNd4jTest { } @Test() - public void testAccessException_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAccessException_2(Nd4jBackend backend) { assertThrows(IllegalStateException.class,() -> { val array = Nd4j.create(5, 5); val view = array.getRow(0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java index 9dc02b36a..0325187cc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/DeviceLocalNDArrayTests.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.memory; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -39,15 +40,14 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j -@RunWith(Parameterized.class) -public class DeviceLocalNDArrayTests extends BaseNd4jTest { - public DeviceLocalNDArrayTests(Nd4jBackend backend) { - super(backend); - } +public class DeviceLocalNDArrayTests extends BaseNd4jTestWithBackends { + @Test - public void testDeviceLocalStringArray(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDeviceLocalStringArray(Nd4jBackend backend){ val arr = Nd4j.create(Arrays.asList("first", "second"), 2); assertEquals(DataType.UTF8, arr.dataType()); assertArrayEquals(new long[]{2}, arr.shape()); @@ -61,7 +61,9 @@ public class DeviceLocalNDArrayTests extends BaseNd4jTest { } @Test - public void testDtypes(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDtypes(Nd4jBackend backend){ for(DataType globalDType : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){ Nd4j.setDefaultDataTypes(globalDType, globalDType); for(DataType arrayDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){ @@ -74,7 +76,9 @@ public class DeviceLocalNDArrayTests extends BaseNd4jTest { } @Test - public void testDeviceLocalUpdate_1() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDeviceLocalUpdate_1(Nd4jBackend backend) throws Exception { val numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); if (numDevices < 2) return; @@ -118,7 +122,9 @@ public class DeviceLocalNDArrayTests extends BaseNd4jTest { @Test - public void testDelayedDeviceLocalUpdate_1() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDelayedDeviceLocalUpdate_1(Nd4jBackend backend) throws Exception { val numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); if (numDevices < 2) return; @@ -145,7 +151,9 @@ public class DeviceLocalNDArrayTests extends BaseNd4jTest { } @Test - public void testDelayedDeviceLocalUpdate_2() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDelayedDeviceLocalUpdate_2(Nd4jBackend backend) throws Exception { val numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); if (numDevices < 2) return; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java index b60941bc6..54df2223d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java @@ -25,8 +25,10 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.graph.FlatArray; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.enums.AllocationPolicy; @@ -51,11 +53,8 @@ import org.nd4j.nativeblas.NativeOpsHolder; import static org.junit.jupiter.api.Assertions.*; @Slf4j -public class MixedDataTypesTests extends BaseNd4jTest { +public class MixedDataTypesTests extends BaseNd4jTestWithBackends { - public MixedDataTypesTests(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -63,7 +62,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicCreation_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicCreation_1(Nd4jBackend backend) { val array = Nd4j.create(DataType.LONG, 3, 3); assertNotNull(array); @@ -73,7 +74,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicCreation_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicCreation_2(Nd4jBackend backend) { val array = Nd4j.create(DataType.SHORT, 3, 3); assertNotNull(array); @@ -83,7 +86,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicCreation_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicCreation_3(Nd4jBackend backend) { val array = Nd4j.create(DataType.HALF, 3, 3); assertNotNull(array); @@ -93,7 +98,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicCreation_4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicCreation_4(Nd4jBackend backend) { val scalar = Nd4j.scalar(DataType.DOUBLE, 1.0); assertNotNull(scalar); assertEquals(0, scalar.rank()); @@ -103,7 +110,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicCreation_5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicCreation_5(Nd4jBackend backend) { val scalar = Nd4j.scalar(Integer.valueOf(1)); assertNotNull(scalar); assertEquals(0, scalar.rank()); @@ -113,7 +122,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicCreation_5_0() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicCreation_5_0(Nd4jBackend backend) { val scalar = Nd4j.scalar(Long.valueOf(1)); assertNotNull(scalar); assertEquals(0, scalar.rank()); @@ -123,7 +134,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicCreation_5_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicCreation_5_1(Nd4jBackend backend) { val scalar = Nd4j.scalar(Double.valueOf(1)); assertNotNull(scalar); assertEquals(0, scalar.rank()); @@ -133,7 +146,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicCreation_5_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicCreation_5_2(Nd4jBackend backend) { val scalar = Nd4j.scalar(Float.valueOf(1)); assertNotNull(scalar); assertEquals(0, scalar.rank()); @@ -143,7 +158,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicCreation_5_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicCreation_5_3(Nd4jBackend backend) { val scalar = Nd4j.scalar(Short.valueOf((short) 1)); assertNotNull(scalar); assertEquals(0, scalar.rank()); @@ -153,7 +170,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicCreation_5_4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicCreation_5_4(Nd4jBackend backend) { val scalar = Nd4j.scalar(Byte.valueOf((byte) 1)); assertNotNull(scalar); assertEquals(0, scalar.rank()); @@ -163,7 +182,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicCreation_6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicCreation_6(Nd4jBackend backend) { val scalar = Nd4j.scalar(1); assertNotNull(scalar); assertEquals(0, scalar.rank()); @@ -173,7 +194,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicCreation_7() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicCreation_7(Nd4jBackend backend) { val scalar = Nd4j.scalar(1L); assertNotNull(scalar); assertEquals(0, scalar.rank()); @@ -183,7 +206,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicOps_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicOps_1(Nd4jBackend backend) { val exp = new int[]{1,1,1,1,1,1,1,1,1}; val array = Nd4j.create(DataType.INT, 3, 3); assertEquals(DataType.INT, array.dataType()); @@ -194,7 +219,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicOps_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicOps_2(Nd4jBackend backend) { val exp = new int[]{1,1,1,1,1,1,1,1,1}; val arrayX = Nd4j.create(DataType.INT, 3, 3); val arrayY = Nd4j.create(new int[]{1,1,1,1,1,1,1,1,1}, new long[]{3, 3}, DataType.INT); @@ -206,7 +233,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicOps_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicOps_3(Nd4jBackend backend) { if (!NativeOpsHolder.getInstance().getDeviceNativeOps().isExperimentalEnabled()) return; @@ -224,7 +253,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicOps_4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicOps_4(Nd4jBackend backend) { val arrayX = Nd4j.create(new int[]{7,8,7,9,1,1,1,1,1}, new long[]{3, 3}, DataType.LONG); val result = arrayX.maxNumber(); @@ -234,7 +265,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicOps_5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicOps_5(Nd4jBackend backend) { val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); val result = arrayX.meanNumber().floatValue(); @@ -243,7 +276,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicOps_6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicOps_6(Nd4jBackend backend) { val arrayX = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.INT); val z = Nd4j.getExecutioner().exec(new CountNonZero(arrayX)); @@ -255,7 +290,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicOps_7() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicOps_7(Nd4jBackend backend) { val arrayX = Nd4j.create(new float[]{1, 0, Float.NaN, 4}, new long[]{4}, DataType.FLOAT); val z = Nd4j.getExecutioner().exec(new IsInf(arrayX)); @@ -271,7 +308,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicOps_8() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicOps_8(Nd4jBackend backend) { val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.INT); val exp = new long[]{1, 0, 0, 1}; @@ -284,7 +323,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBasicOps_9() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicOps_9(Nd4jBackend backend) { val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); val arrayY = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); val exp = new long[]{1, 0, 0, 1}; @@ -297,7 +338,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testNewAssign_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNewAssign_1(Nd4jBackend backend) { val arrayX = Nd4j.create(DataType.FLOAT, 5); val arrayY = Nd4j.create(new double[]{1, 2, 3, 4, 5}); val exp = Nd4j.create(new float[]{1.f, 2.f, 3.f, 4.f, 5.f}); @@ -308,7 +351,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testNewAssign_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNewAssign_2(Nd4jBackend backend) { val arrayX = Nd4j.create(DataType.INT, 5); val arrayY = Nd4j.create(new double[]{1, 2, 3, 4, 5}); val exp = Nd4j.create(new int[]{1, 2, 3, 4, 5}, new long[]{5}, DataType.INT); @@ -319,7 +364,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testMethods_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMethods_1(Nd4jBackend backend) { val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); val arrayY = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); val exp = Nd4j.create(new int[]{2, 4, 6, 8}, new long[]{4}, DataType.INT); @@ -330,7 +377,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testMethods_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMethods_2(Nd4jBackend backend) { if (!NativeOpsHolder.getInstance().getDeviceNativeOps().isExperimentalEnabled()) return; @@ -345,7 +394,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testMethods_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMethods_3(Nd4jBackend backend) { if (!NativeOpsHolder.getInstance().getDeviceNativeOps().isExperimentalEnabled()) return; @@ -360,7 +411,7 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test() - public void testTypesValidation_1() { + public void testTypesValidation_1(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.LONG); val arrayY = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); @@ -373,7 +424,7 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test() - public void testTypesValidation_2() { + public void testTypesValidation_2(Nd4jBackend backend) { assertThrows(RuntimeException.class,() -> { val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.LONG); @@ -388,7 +439,7 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test() - public void testTypesValidation_3() { + public void testTypesValidation_3(Nd4jBackend backend) { assertThrows(RuntimeException.class,() -> { val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); @@ -397,7 +448,7 @@ public class MixedDataTypesTests extends BaseNd4jTest { } - public void testTypesValidation_4() { + public void testTypesValidation_4(Nd4jBackend backend) { val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.DOUBLE); val arrayE = Nd4j.create(new int[]{2, 2, 3, 8}, new long[]{4}, DataType.INT); @@ -408,7 +459,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { @Test - public void testFlatSerde_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFlatSerde_1(Nd4jBackend backend) { val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); val builder = new FlatBufferBuilder(512); @@ -424,7 +477,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testFlatSerde_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFlatSerde_2(Nd4jBackend backend) { val arrayX = Nd4j.create(new long[]{1, 2, 3, 4}, new long[]{4}, DataType.LONG); val builder = new FlatBufferBuilder(512); @@ -440,7 +495,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testFlatSerde_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFlatSerde_3(Nd4jBackend backend) { val arrayX = Nd4j.create(new boolean[]{true, false, true, true}, new long[]{4}, DataType.BOOL); val builder = new FlatBufferBuilder(512); @@ -456,6 +513,8 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBoolFloatCast2(){ val first = Nd4j.zeros(DataType.FLOAT, 3, 5000); INDArray asBool = first.castTo(DataType.BOOL); @@ -476,7 +535,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testReduce3Large() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReduce3Large(Nd4jBackend backend) { val arrayX = Nd4j.create(DataType.FLOAT, 10, 5000); val arrayY = Nd4j.create(DataType.FLOAT, 10, 5000); @@ -485,6 +546,8 @@ public class MixedDataTypesTests extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testAssignScalarSimple(){ for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { INDArray arr = Nd4j.scalar(dt, 10.0); @@ -494,6 +557,8 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSimple(){ Nd4j.create(1); for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, DataType.LONG}) { @@ -518,6 +583,8 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWorkspaceBool(){ val conf = WorkspaceConfiguration.builder().minSize(10 * 1024 * 1024) .overallocationLimit(1.0).policyAllocation(AllocationPolicy.OVERALLOCATE) @@ -543,7 +610,7 @@ public class MixedDataTypesTests extends BaseNd4jTest { @Test @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") - public void testArrayCreationFromPointer() { + public void testArrayCreationFromPointer(Nd4jBackend backend) { val source = Nd4j.create(new double[]{1, 2, 3, 4, 5}); val pAddress = source.data().addressPointer(); @@ -561,7 +628,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testBfloat16_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBfloat16_1(Nd4jBackend backend) { val x = Nd4j.create(DataType.BFLOAT16, 5); val y = Nd4j.createFromArray(new int[]{2, 2, 2, 2, 2}).castTo(DataType.BFLOAT16); @@ -570,7 +639,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testUint16_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUint16_1(Nd4jBackend backend) { val x = Nd4j.create(DataType.UINT16, 5); val y = Nd4j.createFromArray(new int[]{2, 2, 2, 2, 2}).castTo(DataType.UINT16); @@ -579,7 +650,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testUint32_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUint32_1(Nd4jBackend backend) { val x = Nd4j.create(DataType.UINT32, 5); val y = Nd4j.createFromArray(new int[]{2, 2, 2, 2, 2}).castTo(DataType.UINT32); @@ -588,7 +661,9 @@ public class MixedDataTypesTests extends BaseNd4jTest { } @Test - public void testUint64_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUint64_1(Nd4jBackend backend) { val x = Nd4j.create(DataType.UINT64, 5); val y = Nd4j.createFromArray(new int[]{2, 2, 2, 2, 2}).castTo(DataType.UINT64); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/StringArrayTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/StringArrayTests.java index c14020d22..79268a06a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/StringArrayTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/StringArrayTests.java @@ -24,8 +24,10 @@ import com.google.flatbuffers.FlatBufferBuilder; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.graph.FlatArray; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -33,11 +35,8 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.*; @Slf4j -public class StringArrayTests extends BaseNd4jTest { +public class StringArrayTests extends BaseNd4jTestWithBackends { - public StringArrayTests(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -45,7 +44,9 @@ public class StringArrayTests extends BaseNd4jTest { } @Test - public void testBasicStrings_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicStrings_1(Nd4jBackend backend) { val array = Nd4j.scalar("alpha"); assertNotNull(array); @@ -60,7 +61,9 @@ public class StringArrayTests extends BaseNd4jTest { } @Test - public void testBasicStrings_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicStrings_2(Nd4jBackend backend) { val array = Nd4j.create("alpha","beta", "gamma"); assertNotNull(array); @@ -79,6 +82,8 @@ public class StringArrayTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBasicStrings_3() { val arrayX = Nd4j.create("alpha", "beta", "gamma"); val arrayY = Nd4j.create("alpha", "beta", "gamma"); @@ -90,6 +95,8 @@ public class StringArrayTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBasicStrings_4() { val arrayX = Nd4j.create("alpha", "beta", "gamma"); @@ -108,6 +115,8 @@ public class StringArrayTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBasicStrings_4a() { val arrayX = Nd4j.scalar("alpha"); @@ -126,6 +135,8 @@ public class StringArrayTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBasicStrings_5() { val arrayX = Nd4j.create("alpha", "beta", "gamma"); val arrayZ0 = arrayX.dup(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java index 821499afc..fc0c780a4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/multithreading/MultithreadedTests.java @@ -22,22 +22,19 @@ package org.nd4j.linalg.multithreading; import lombok.val; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.factory.Nd4jBackend; import java.util.ArrayList; import java.util.HashSet; import static org.junit.jupiter.api.Assertions.assertEquals; -public class MultithreadedTests extends BaseNd4jTest { - - public MultithreadedTests(Nd4jBackend backend) { - super(backend); - } +public class MultithreadedTests extends BaseNd4jTestWithBackends { @Override public char ordering() { @@ -45,6 +42,8 @@ public class MultithreadedTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void basicMigrationTest_1() throws Exception { if (Nd4j.getAffinityManager().getNumberOfDevices() < 2) return; @@ -57,21 +56,18 @@ public class MultithreadedTests extends BaseNd4jTest { val list = new ArrayList(); for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { val t = e; - val thread = new Thread(new Runnable() { - @Override - public void run() { - for (int f = 0; f < 10; f++) { - val array = Nd4j.create(DataType.INT32, 5, 5).assign(1); + val thread = new Thread(() -> { + for (int f = 0; f < 10; f++) { + val array = Nd4j.create(DataType.INT32, 5, 5).assign(1); - // store current deviceId for further validation - hash.add(Nd4j.getAffinityManager().getDeviceForCurrentThread()); + // store current deviceId for further validation + hash.add(Nd4j.getAffinityManager().getDeviceForCurrentThread()); - // make sure INDArray has proper affinity set - assertEquals(Nd4j.getAffinityManager().getDeviceForCurrentThread(), Nd4j.getAffinityManager().getDeviceForArray(array)); + // make sure INDArray has proper affinity set + assertEquals(Nd4j.getAffinityManager().getDeviceForCurrentThread(), Nd4j.getAffinityManager().getDeviceForArray(array)); - list.add(array); - } - }; + list.add(array); + } }); thread.start(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java index e5752cff4..86801bb18 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java @@ -25,7 +25,9 @@ import lombok.val; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; @@ -34,26 +36,25 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j -public class NativeBlasTests extends BaseNd4jTest { +public class NativeBlasTests extends BaseNd4jTestWithBackends { - public NativeBlasTests(Nd4jBackend backend) { - super(backend); - } @BeforeEach - public void setUp() { + public void setUp(Nd4jBackend backend) { Nd4j.getExecutioner().enableDebugMode(true); Nd4j.getExecutioner().enableVerboseMode(true); } @AfterEach - public void setDown() { + public void setDown(Nd4jBackend backend) { Nd4j.getExecutioner().enableDebugMode(false); Nd4j.getExecutioner().enableVerboseMode(false); } @Test - public void testBlasGemm1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasGemm1(Nd4jBackend backend) { // we're skipping blas here if (Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) @@ -79,7 +80,9 @@ public class NativeBlasTests extends BaseNd4jTest { @Test - public void testBlasGemm2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasGemm2(Nd4jBackend backend) { // we're skipping blas here if (Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) @@ -105,7 +108,9 @@ public class NativeBlasTests extends BaseNd4jTest { @Test - public void testBlasGemm3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasGemm3(Nd4jBackend backend) { // we're skipping blas here if (Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) @@ -131,7 +136,9 @@ public class NativeBlasTests extends BaseNd4jTest { @Test - public void testBlasGemm4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasGemm4(Nd4jBackend backend) { // we're skipping blas here if (Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) @@ -157,7 +164,9 @@ public class NativeBlasTests extends BaseNd4jTest { @Test - public void testBlasGemm5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasGemm5(Nd4jBackend backend) { // we're skipping blas here if (Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) @@ -182,7 +191,9 @@ public class NativeBlasTests extends BaseNd4jTest { } @Test - public void testBlasGemm6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasGemm6(Nd4jBackend backend) { // we're skipping blas here if (Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) @@ -208,7 +219,9 @@ public class NativeBlasTests extends BaseNd4jTest { @Test - public void testBlasGemm7() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasGemm7(Nd4jBackend backend) { // we're skipping blas here if (Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) @@ -236,7 +249,9 @@ public class NativeBlasTests extends BaseNd4jTest { @Test - public void testBlasGemv1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasGemv1(Nd4jBackend backend) { // we're skipping blas here if (Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) @@ -264,7 +279,9 @@ public class NativeBlasTests extends BaseNd4jTest { @Test - public void testBlasGemv2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasGemv2(Nd4jBackend backend) { // we're skipping blas here if (Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) @@ -292,7 +309,9 @@ public class NativeBlasTests extends BaseNd4jTest { @Test - public void testBlasGemv3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasGemv3(Nd4jBackend backend) { // we're skipping blas here if (Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java index c2dca756e..df38ab49d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java @@ -24,10 +24,12 @@ import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ops.BaseBroadcastOp; import org.nd4j.linalg.api.ops.BaseIndexAccumulation; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; @@ -53,11 +55,8 @@ import java.util.List; import java.util.Set; @Slf4j -public class OpsMappingTests extends BaseNd4jTest { +public class OpsMappingTests extends BaseNd4jTestWithBackends { - public OpsMappingTests(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -70,7 +69,9 @@ public class OpsMappingTests extends BaseNd4jTest { } @Test - public void testLegacyOpsMapping() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLegacyOpsMapping(Nd4jBackend backend) { Nd4j.create(1); val str = NativeOpsHolder.getInstance().getDeviceNativeOps().getAllOperations().replaceAll("simdOps::","").replaceAll("randomOps::",""); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java index f623847e1..8420cff48 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java @@ -24,9 +24,10 @@ import org.apache.commons.math3.util.FastMath; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.scalar.Step; @@ -45,18 +46,13 @@ import org.nd4j.linalg.ops.transforms.Transforms; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class DerivativeTests extends BaseNd4jTest { + +public class DerivativeTests extends BaseNd4jTestWithBackends { public static final double REL_ERROR_TOLERANCE = 1e-3; - DataType initialType; - - public DerivativeTests(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } + DataType initialType = Nd4j.dataType(); @BeforeEach public void before() { @@ -69,7 +65,9 @@ public class DerivativeTests extends BaseNd4jTest { } @Test - public void testHardTanhDerivative() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHardTanhDerivative(Nd4jBackend backend) { //HardTanh: //f(x) = 1 if x > 1 //f(x) = -1 if x < -1 @@ -95,7 +93,9 @@ public class DerivativeTests extends BaseNd4jTest { } @Test - public void testRectifiedLinearDerivative() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRectifiedLinearDerivative(Nd4jBackend backend) { //ReLU: //f(x) = max(0,x) //Piecewise differentiable; choose f'(0) = 0 @@ -118,7 +118,9 @@ public class DerivativeTests extends BaseNd4jTest { } @Test - public void testSigmoidDerivative() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSigmoidDerivative(Nd4jBackend backend) { //Derivative of sigmoid: ds(x)/dx = s(x)*(1-s(x)) //s(x) = 1 / (exp(-x) + 1) INDArray z = Nd4j.zeros(100); @@ -141,7 +143,9 @@ public class DerivativeTests extends BaseNd4jTest { @Test - public void testHardSigmoidDerivative() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHardSigmoidDerivative(Nd4jBackend backend) { /* f(x) = min(1, max(0, 0.2*x + 0.5)) or equivalently: clip 0.2*x+0.5 to range 0 to 1 @@ -194,7 +198,9 @@ public class DerivativeTests extends BaseNd4jTest { @Test - public void testSoftPlusDerivative() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftPlusDerivative(Nd4jBackend backend) { //s(x) = 1 / (exp(-x) + 1) INDArray z = Nd4j.zeros(100); double[] expOut = new double[100]; @@ -214,7 +220,9 @@ public class DerivativeTests extends BaseNd4jTest { } @Test - public void testTanhDerivative() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTanhDerivative(Nd4jBackend backend) { //Derivative of sigmoid: ds(x)/dx = s(x)*(1-s(x)) //s(x) = 1 / (exp(-x) + 1) @@ -237,7 +245,9 @@ public class DerivativeTests extends BaseNd4jTest { } @Test - public void testCubeDerivative() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCubeDerivative(Nd4jBackend backend) { //Derivative of cube: 3*x^2 INDArray z = Nd4j.zeros(100); @@ -262,7 +272,9 @@ public class DerivativeTests extends BaseNd4jTest { } @Test - public void testLeakyReLUDerivative() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLeakyReLUDerivative(Nd4jBackend backend) { //Derivative: 0.01 if x<0, 1 otherwise INDArray z = Nd4j.zeros(100); double[] expOut = new double[100]; @@ -282,7 +294,9 @@ public class DerivativeTests extends BaseNd4jTest { } @Test - public void testSoftSignDerivative() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftSignDerivative(Nd4jBackend backend) { //Derivative: 1 / (1+abs(x))^2 INDArray z = Nd4j.zeros(100).castTo(DataType.DOUBLE); double[] expOut = new double[100]; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpConstructorTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpConstructorTests.java index 1addead6c..ad19795d6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpConstructorTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpConstructorTests.java @@ -22,9 +22,11 @@ package org.nd4j.linalg.ops; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ops.NoOp; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.lossfunctions.ILossFunction; @@ -41,11 +43,7 @@ import java.util.*; import static org.junit.jupiter.api.Assertions.assertEquals; @Disabled //AB 2019/08/23 Ignored for now -public class OpConstructorTests extends BaseNd4jTest { - - public OpConstructorTests(Nd4jBackend backend) { - super(backend); - } +public class OpConstructorTests extends BaseNd4jTestWithBackends { //Ignore individual classes protected Set> exclude = new HashSet<>( @@ -60,7 +58,9 @@ public class OpConstructorTests extends BaseNd4jTest { }; @Test - public void checkForINDArrayConstructors() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void checkForINDArrayConstructors(Nd4jBackend backend) throws Exception { /* Check that all op classes have at least one INDArray or INDArray[] constructor, so they can actually be used outside of SameDiff @@ -109,12 +109,7 @@ public class OpConstructorTests extends BaseNd4jTest { } if(!classes.isEmpty()){ - Collections.sort(classes, new Comparator>() { - @Override - public int compare(Class o1, Class o2) { - return o1.getName().compareTo(o2.getName()); - } - }); + Collections.sort(classes, Comparator.comparing(Class::getName)); for(Class c : classes){ System.out.println("No INDArray constructor: " + c.getName()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java index 1b9de3efa..2ac5dc02a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.ops; import lombok.val; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; @@ -67,18 +68,13 @@ import java.util.concurrent.Executors; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class OpExecutionerTests extends BaseNd4jTest { - - - public OpExecutionerTests(Nd4jBackend backend) { - super(backend); - } - +public class OpExecutionerTests extends BaseNd4jTestWithBackends { @Test - public void testCosineSimilarity() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCosineSimilarity(Nd4jBackend backend) { INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5}); INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5}); double sim = Transforms.cosineSim(vec1, vec2); @@ -87,6 +83,8 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCosineDistance(){ INDArray vec1 = Nd4j.create(new float[] {1, 2, 3}); INDArray vec2 = Nd4j.create(new float[] {3, 5, 7}); @@ -96,7 +94,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testEuclideanDistance() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEuclideanDistance(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[] {55, 55}); INDArray arr2 = Nd4j.create(new double[] {60, 60}); double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).z().getDouble(0); @@ -104,7 +104,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testDimensionalEuclidean() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDimensionalEuclidean(Nd4jBackend backend) { INDArray distanceInputRow = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1); INDArray distanceComp = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1).add(1); INDArray result = Nd4j.createUninitialized(DataType.DOUBLE, 4); @@ -124,12 +126,9 @@ public class OpExecutionerTests extends BaseNd4jTest { INDArray rowVector = matrix.getRow(70); INDArray resultArr = Nd4j.zeros(400,1); Executor executor = Executors.newSingleThreadExecutor(); - executor.execute(new Runnable() { - @Override - public void run() { - Nd4j.getExecutioner().exec(new EuclideanDistance(matrix, rowVector, resultArr, -1)); - System.out.println("Ran!"); - } + executor.execute(() -> { + Nd4j.getExecutioner().exec(new EuclideanDistance(matrix, rowVector, resultArr, -1)); + System.out.println("Ran!"); }); Thread.sleep(600000); @@ -137,7 +136,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testScalarMaxOp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarMaxOp(Nd4jBackend backend) { INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi(); INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6); Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1)); @@ -145,7 +146,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testSetRange() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSetRange(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1)); for (int i = 0; i < linspace.length(); i++) { @@ -162,14 +165,18 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testNormMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNormMax(Nd4jBackend backend) { INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4}); double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).z().getDouble(0); assertEquals(4, normMax, 1e-1,getFailureMessage()); } @Test - public void testLog() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLog(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray assertion = Nd4j.create(new double[][] {{0., 1.09861229}, {0.69314718, 1.38629436}}); @@ -184,14 +191,18 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testNorm2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm2(Nd4jBackend backend) { INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4}); double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).z().getDouble(0); assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage()); } @Test - public void testAdd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdd(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.ones(5); INDArray xDup = x.dup(); @@ -201,7 +212,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testMul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMul(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.ones(5); INDArray xDup = x.dup(); @@ -212,7 +225,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testExecutioner() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExecutioner(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.ones(5); INDArray xDup = x.dup(); @@ -229,7 +244,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testMaxMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMaxMin(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); Max max = new Max(x); @@ -241,7 +258,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testProd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testProd(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); Prod prod = new Prod(linspace); double prod2 = Nd4j.getExecutioner().execAndReturn(prod).z().getDouble(0); @@ -249,7 +268,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testSum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); Sum sum = new Sum(linspace); double sum2 = Nd4j.getExecutioner().execAndReturn(sum).z().getDouble(0); @@ -258,7 +279,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testDescriptiveStatsDouble() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDescriptiveStatsDouble(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); @@ -274,13 +297,17 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testIamax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIamax(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage()); } @Test - public void testIamax2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIamax2(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage()); val op = new ArgAmax(linspace); @@ -291,7 +318,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testDescriptiveStats() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDescriptiveStats(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); @@ -305,7 +334,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testRowSoftmax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRowSoftmax(Nd4jBackend backend) { val opExecutioner = Nd4j.getExecutioner(); val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); @@ -315,7 +346,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testPow() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPow(Nd4jBackend backend) { INDArray oneThroughSix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); Pow pow = new Pow(oneThroughSix, 2); Nd4j.getExecutioner().exec(pow); @@ -325,7 +358,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testComparisonOps() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testComparisonOps(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); INDArray ones = Nd4j.ones(DataType.BOOL, 6); INDArray zeros = Nd4j.zeros(DataType.BOOL, 6); @@ -337,7 +372,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testScalarArithmetic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarArithmetic(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); INDArray plusOne = Nd4j.linspace(2, 7, 6, DataType.DOUBLE); Nd4j.getExecutioner().exec(new ScalarAdd(linspace, 1)); @@ -346,7 +383,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testDimensionMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDimensionMax(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); int axis = 0; INDArray row = linspace.slice(axis); @@ -361,7 +400,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testStridedLog() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStridedLog(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray slice = arr.slice(0); @@ -372,7 +413,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testSoftmax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftmax(Nd4jBackend backend) { INDArray vec = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); INDArray matrix = vec.dup().reshape('f', 2, 3); Nd4j.getExecutioner().exec((CustomOp) new SoftMax(matrix)); @@ -383,7 +426,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testOtherSoftmax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOtherSoftmax(Nd4jBackend backend) { INDArray vec = Nd4j.linspace(1, 18, 18, DataType.DOUBLE); INDArray matrix = vec.dup().reshape('f', 3, 6); Nd4j.getExecutioner().exec((CustomOp) new SoftMax(matrix)); @@ -396,7 +441,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testClassificationSoftmax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testClassificationSoftmax(Nd4jBackend backend) { INDArray input = Nd4j.create(new double[] {-0.11537042, -0.12137824, -0.12023379, -0.121212654, -0.11363918, -0.10101747, -0.11571036, -0.11699755, -0.12303393, -0.12222538, -0.111205295, -0.11710347, -0.12319956, -0.12442437, -0.10528548, -0.08768979, -0.102969095, -0.11346512, -0.106075466, @@ -527,7 +574,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testAddBroadcast() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAddBroadcast(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape('f', 2, 3); INDArray arrRow = Nd4j.create(new double[] {1, 2, 3}); INDArray assertion = Nd4j.create(new double[] {2, 3, 5, 6, 8, 9}, new int[] {2, 3}, 'f'); @@ -542,7 +591,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testStridedExp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStridedExp(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray slice = arr.slice(0); @@ -555,7 +606,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testSoftMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftMax(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); @@ -564,7 +617,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testIMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIMax(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); ArgMax imax = new ArgMax(arr); assertEquals(9, Nd4j.getExecutioner().exec(imax)[0].getInt(0)); @@ -576,7 +631,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testIMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIMin(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); ArgMin imin = new ArgMin(arr); assertEquals(0, Nd4j.getExecutioner().exec(imin)[0].getInt(0)); @@ -589,7 +646,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testMeanSumSimple() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMeanSumSimple(Nd4jBackend backend) { // System.out.println("3d"); INDArray arr = Nd4j.ones(1, 4, 4); assertEquals(Nd4j.ones(1), arr.mean(1, 2)); @@ -626,7 +685,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void tescodtSum6d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void tescodtSum6d(Nd4jBackend backend) { INDArray arr6 = Nd4j.ones(1, 1, 4, 4, 4, 4); INDArray arr6s = arr6.sum(2, 3); @@ -636,7 +697,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testSum6d2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum6d2(Nd4jBackend backend) { char origOrder = Nd4j.order(); try { for (char order : new char[]{'c', 'f'}) { @@ -673,7 +736,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testMean6d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMean6d(Nd4jBackend backend) { INDArray arr6 = Nd4j.ones(1, 1, 4, 4, 4, 4); INDArray arr6m = arr6.mean(2, 3); @@ -691,7 +756,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testStdev() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStdev(Nd4jBackend backend) { INDArray arr = Nd4j.create(new float[] {0.9296161f, 0.31637555f, 0.1839188f}, new int[] {1, 3}, ordering()); double stdev = arr.stdNumber(true).doubleValue(); @@ -706,7 +773,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testVariance() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariance(Nd4jBackend backend) { val f = new double[] {0.9296161, 0.31637555, 0.1839188}; INDArray arr = Nd4j.create(f, new int[] {1, 3}, ordering()); double var = arr.varNumber().doubleValue(); @@ -721,7 +790,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testDropout() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDropout(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 100, 100, DataType.DOUBLE); INDArray result = Nd4j.create(DataType.DOUBLE, 100); @@ -735,7 +806,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testDropoutInverted() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDropoutInverted(Nd4jBackend backend) { INDArray array = Nd4j.linspace(1, 100, 100, DataType.DOUBLE); INDArray result = Nd4j.create(DataType.DOUBLE, 100); @@ -749,7 +822,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testVPull1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVPull1(Nd4jBackend backend) { int indexes[] = new int[] {0, 2, 4}; INDArray array = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(5, 5); INDArray assertion = Nd4j.createUninitialized(DataType.DOUBLE, new long[] {3, 5}, 'f'); @@ -765,7 +840,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testVPull2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVPull2(Nd4jBackend backend) { int indexes[] = new int[] {0, 2, 4}; INDArray array = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(5, 5); INDArray assertion = Nd4j.createUninitialized(DataType.DOUBLE, new long[] {3, 5}, 'c'); @@ -785,7 +862,9 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test - public void testPile1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPile1(Nd4jBackend backend) { List arrays = new ArrayList<>(); for (int i = 0; i < 10; i++) { arrays.add(Nd4j.create(10, 10).assign(i)); @@ -800,7 +879,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testPile2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPile2(Nd4jBackend backend) { List arrays = new ArrayList<>(); for (int i = 0; i < 10; i++) { arrays.add(Nd4j.create(10, 10, 10).assign(i)); @@ -815,7 +896,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testPile3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPile3(Nd4jBackend backend) { List arrays = new ArrayList<>(); for (int i = 0; i < 10; i++) { arrays.add(Nd4j.create( 10, 10).assign(i)); @@ -830,7 +913,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testPile4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPile4(Nd4jBackend backend) { val arrayW = Nd4j.create(1, 5); val arrayX = Nd4j.create(1, 5); val arrayY = Nd4j.create(1, 5); @@ -841,7 +926,9 @@ public class OpExecutionerTests extends BaseNd4jTest { } @Test - public void testTear1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTear1(Nd4jBackend backend) { List arrays = new ArrayList<>(); for (int i = 0; i < 10; i++) { arrays.add(Nd4j.create(10, 10).assign(i)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java index d52c39755..865bd81d9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java @@ -25,9 +25,10 @@ import lombok.val; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; @@ -78,29 +79,25 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.point; @Slf4j -@RunWith(Parameterized.class) -public class OpExecutionerTestsC extends BaseNd4jTest { - public OpExecutionerTestsC(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } - - DataType initialType; +public class OpExecutionerTestsC extends BaseNd4jTestWithBackends { + DataType initialType = Nd4j.dataType(); @AfterEach - public void after() { + public void after(Nd4jBackend backend) { Nd4j.setDataType(this.initialType); } @Test - public void testSoftmaxReference() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftmaxReference(Nd4jBackend backend) { INDArray input = Nd4j.linspace(1,4,4, DataType.FLOAT).reshape(2,2); INDArray dup = input.dup(); - Nd4j.getExecutioner().exec((CustomOp) new SoftMax(dup)); + Nd4j.getExecutioner().exec(new SoftMax(dup)); INDArray result = Nd4j.zeros(DataType.FLOAT, 2,2); - Nd4j.getExecutioner().exec((CustomOp) new SoftMax(input,result)); + Nd4j.getExecutioner().exec(new SoftMax(input,result)); assertEquals(dup,result); @@ -114,7 +111,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testScalarReverseSub() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarReverseSub(Nd4jBackend backend) { INDArray input = Nd4j.valueArrayOf(4,2.0); INDArray result= Nd4j.zeros(4); Nd4j.getExecutioner().exec(new ScalarReverseSubtraction(input,null,result,1.0)); @@ -124,20 +123,24 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testBroadcastMultiDim() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBroadcastMultiDim(Nd4jBackend backend) { INDArray data = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(2, 3, 5); // System.out.println(data); INDArray mask = Nd4j.create(new double[][] {{1.00, 1.00, 1.00, 1.00, 1.00}, {1.00, 1.00, 1.00, 0.00, 0.00}}); Nd4j.getExecutioner().exec(new BroadcastMulOp(data, mask, data, 0, 2)); INDArray assertion = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 0.0, 0.0, 21.0, 22.0, 23.0, 0.0, 0.0, 26.0, 27.0, 28.0, 0.0, - 0.0}).reshape(2, 3, 5); + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 0.0, 0.0, 21.0, 22.0, 23.0, 0.0, 0.0, 26.0, 27.0, 28.0, 0.0, + 0.0}).reshape(2, 3, 5); assertEquals(assertion, data); } @Test - public void testCosineSimilarity() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCosineSimilarity(Nd4jBackend backend) { INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5}); INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5}); double sim = Transforms.cosineSim(vec1, vec2); @@ -145,6 +148,8 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCosineDistance(){ INDArray vec1 = Nd4j.create(new float[] {1, 2, 3}); INDArray vec2 = Nd4j.create(new float[] {3, 5, 7}); @@ -154,7 +159,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testLog() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLog(Nd4jBackend backend) { INDArray log = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); INDArray transformed = Transforms.log(log); INDArray assertion = Nd4j.create(new double[] {0., 0.69314718, 1.09861229, 1.38629436, 1.60943791, 1.79175947}); @@ -162,7 +169,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testNorm1AlongDimension() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm1AlongDimension(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(2, 4); INDArray arrNorm1 = arr.norm2(1); INDArray assertion = Nd4j.create(new double[] {5.47722558, 13.19090596}); @@ -171,16 +180,20 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testEuclideanDistance() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEuclideanDistance(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[] {55, 55}); INDArray arr2 = Nd4j.create(new double[] {60, 60}); double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).getFinalResult() - .doubleValue(); + .doubleValue(); assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage()); } @Test - public void testScalarMaxOp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarMaxOp(Nd4jBackend backend) { INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi(); INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6); Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1)); @@ -188,7 +201,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testSetRange() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSetRange(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1)); for (int i = 0; i < linspace.length(); i++) { @@ -206,7 +221,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testNormMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNormMax(Nd4jBackend backend) { INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4}); double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).getFinalResult().doubleValue(); assertEquals(4, normMax, 1e-1,getFailureMessage()); @@ -214,14 +231,18 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testNorm2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm2(Nd4jBackend backend) { INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4}); double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).getFinalResult().doubleValue(); assertEquals( 5.4772255750516612, norm2, 1e-1,getFailureMessage()); } @Test - public void testAdd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAdd(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.ones(5); INDArray xDup = x.dup(); @@ -231,7 +252,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testMul() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMul(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.ones(5); INDArray xDup = x.dup(); @@ -242,7 +265,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testExecutioner() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExecutioner(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.ones(5); INDArray xDup = x.dup(); @@ -259,7 +284,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testMaxMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMaxMin(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); Max max = new Max(x); @@ -271,7 +298,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testProd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testProd(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); Prod prod = new Prod(linspace); double prod2 = Nd4j.getExecutioner().execAndReturn(prod).getFinalResult().doubleValue(); @@ -279,7 +308,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testSum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); Sum sum = new Sum(linspace); double sum2 = Nd4j.getExecutioner().execAndReturn(sum).getFinalResult().doubleValue(); @@ -292,7 +323,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testDescriptiveStatsDouble() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDescriptiveStatsDouble(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); @@ -307,7 +340,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testDescriptiveStats() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDescriptiveStats(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray x = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); @@ -321,7 +356,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testRowSoftmax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRowSoftmax(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); @@ -330,7 +367,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testAddiRowVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAddiRowVector(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray arr2 = Nd4j.linspace(1, 3, 3, DataType.DOUBLE); INDArray assertion = Nd4j.create(new double[] {2, 4, 6, 5, 7, 9}).reshape(2, 3); @@ -339,7 +378,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testTad() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTad(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(2, 3, 2); for (int i = 0; i < arr.tensorsAlongDimension(0); i++) { // System.out.println(arr.tensorAlongDimension(i, 0)); @@ -349,7 +390,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testPow() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPow(Nd4jBackend backend) { INDArray oneThroughSix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); Pow pow = new Pow(oneThroughSix, 2); Nd4j.getExecutioner().exec(pow); @@ -359,7 +402,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testComparisonOps() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testComparisonOps(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); INDArray ones = Nd4j.ones(DataType.BOOL, 1,6); INDArray zeros = Nd4j.create(DataType.BOOL, 1,6); @@ -371,7 +416,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testScalarArithmetic() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarArithmetic(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); INDArray plusOne = Nd4j.linspace(2, 7, 6, DataType.DOUBLE); Nd4j.getExecutioner().exec(new ScalarAdd(linspace, 1)); @@ -379,7 +426,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testDimensionMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDimensionMax(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); int axis = 0; INDArray row = linspace.slice(axis); @@ -397,7 +446,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testStridedLog() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStridedLog(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray slice = arr.slice(0); @@ -408,7 +459,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testStridedExp() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStridedExp(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray slice = arr.slice(0); @@ -421,7 +474,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testSoftMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftMax(Nd4jBackend backend) { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val softMax = new SoftMax(arr); @@ -436,7 +491,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testDimensionSoftMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDimensionSoftMax(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); val max = new SoftMax(linspace); Nd4j.getExecutioner().exec(max); @@ -445,7 +502,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testColumnMean() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testColumnMean(Nd4jBackend backend) { INDArray twoByThree = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray columnMean = twoByThree.mean(0); INDArray assertion = Nd4j.create(new double[] {2, 3}); @@ -454,7 +513,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testColumnVar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testColumnVar(Nd4jBackend backend) { INDArray twoByThree = Nd4j.linspace(1, 600, 600, DataType.DOUBLE).reshape(150, 4); INDArray columnStd = twoByThree.var(0); INDArray assertion = Nd4j.create(new double[] {30200f, 30200f, 30200f, 30200f}); @@ -462,7 +523,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testColumnStd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testColumnStd(Nd4jBackend backend) { INDArray twoByThree = Nd4j.linspace(1, 600, 600, DataType.DOUBLE).reshape(150, 4); INDArray columnStd = twoByThree.std(0); INDArray assertion = Nd4j.create(new double[] {173.78147196982766f, 173.78147196982766f, 173.78147196982766f, 173.78147196982766f}); @@ -470,14 +533,18 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testDim1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDim1(Nd4jBackend backend) { INDArray sum = Nd4j.linspace(1, 2, 2, DataType.DOUBLE).reshape(2, 1); INDArray same = sum.dup(); assertEquals(same.sum(1), sum.reshape(2)); } @Test - public void testIMax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIMax(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); ArgMax imax = new ArgMax(arr); assertEquals(9, Nd4j.getExecutioner().exec(imax)[0].getInt(0)); @@ -489,7 +556,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testIMin() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIMin(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); ArgMin imin = new ArgMin(arr); assertEquals(0, Nd4j.getExecutioner().exec(imin)[0].getInt(0)); @@ -503,7 +572,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testMeanSumSimple() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMeanSumSimple(Nd4jBackend backend) { // System.out.println("3d"); INDArray arr = Nd4j.ones(1, 4, 4); assertEquals(Nd4j.ones(1), arr.mean(1, 2)); @@ -539,7 +610,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testSum6d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum6d(Nd4jBackend backend) { INDArray arr6 = Nd4j.ones(1, 1, 4, 4, 4, 4); INDArray arr6s = arr6.sum(2, 3); for (int i = 0; i < arr6s.length(); i++) @@ -548,7 +621,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testMean() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMean(Nd4jBackend backend) { int[] shape = new int[] {1, 2, 2, 2, 2, 2}; int len = ArrayUtil.prod(shape); INDArray val = Nd4j.linspace(1, len, len, DataType.DOUBLE).reshape('c', shape); @@ -590,6 +665,8 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSum5d() throws Exception { // System.out.println("5d"); INDArray arr5 = Nd4j.ones(1, 1, 4, 4, 4); @@ -606,7 +683,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testOneMinus() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOneMinus(Nd4jBackend backend) { INDArray in = Nd4j.linspace(1, 3, 3, DataType.DOUBLE); INDArray out = Transforms.timesOneMinus(in, true); @@ -618,18 +697,22 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testSubColumnVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSubColumnVector(Nd4jBackend backend) { INDArray vec = Nd4j.linspace(1, 18, 18, DataType.DOUBLE); INDArray matrix = vec.dup().reshape(3, 6); INDArray vector = Nd4j.create(new double[] {6, 12, 18}).reshape(3, 1); INDArray assertion = Nd4j.create(new double[] {-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, -5.0, -4.0, -3.0, -2.0, -1.0, - 0.0, -5.0, -4.0, -3.0, -2.0, -1.0, 0.0}, new int[] {3, 6}); + 0.0, -5.0, -4.0, -3.0, -2.0, -1.0, 0.0}, new int[] {3, 6}); INDArray test = matrix.subColumnVector(vector); assertEquals(assertion, test); } @Test - public void testLogSoftmaxVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLogSoftmaxVector(Nd4jBackend backend) { INDArray temp = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}); INDArray logsoftmax = Nd4j.getExecutioner().exec(new LogSoftMax(temp.dup()))[0]; INDArray assertion = Nd4j.create(new double[] {-3.4401898, -2.4401898, -1.4401897, -0.44018975}); @@ -639,7 +722,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testSumDifferentOrder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSumDifferentOrder(Nd4jBackend backend) { INDArray toAssign = Nd4j.linspace(0, 3, 4, DataType.DOUBLE).reshape(2, 2); INDArray cOrder = Nd4j.create(new int[] {2, 2}, 'c').assign(toAssign); INDArray fOrder = Nd4j.create(new int[] {2, 2}, 'f').assign(toAssign); @@ -653,127 +738,129 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testLogSoftmax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLogSoftmax(Nd4jBackend backend) { INDArray test = Nd4j.create(new double[] {-0.115370326, -0.12137828, -0.120233774, -0.12121266, -0.11363905, - -0.101017155, -0.11571029, -0.116997495, -0.123033985, -0.1222254, -0.11120513, -0.11710341, - -0.12319958, -0.124424405, -0.105285235, -0.08768927, -0.10296882, -0.11346505, -0.10607526, - -0.10681274, -0.11604863, -0.1070115, -0.114202365, -0.11168295, -0.11615404, -0.120522454, - -0.11282451, -0.11514864, -0.11681116, -0.11987897, -0.12054029, -0.112625614, -0.10337835, - -0.098809384, -0.1222254, -0.11966098, -0.11500366, -0.1222254, -0.122691356, -0.1168594, - -0.11369472, -0.11666928, -0.12075868, -0.10658686, -0.10251844, -0.119958505, -0.10873747, - -0.12036781, -0.11125211, -0.118474, 0.07354958, 0.06268418, 0.08751996, 0.05259535, 0.07969022, - 0.062334962, 0.07089297, -0.006484107, 0.0702586, 0.03601057, 0.03228142, 0.051330067, - 0.048092633, 0.0753836, 0.0026741663, 0.060346458, 0.064265735, 0.03208362, 0.07322607, - 0.034286126, 0.08459597, 0.040570714, 0.08494339, 0.06835921, 0.055334114, 0.06346921, - 0.08284429, 0.09769646, 0.07128828, 0.0012985547, 0.033257447, 0.024084045, 0.03130147, - 0.09381818, 0.062283173, 0.049273495, 0.0789609, 0.06648661, 0.030163772, 0.047266945, - 0.05704684, 0.06862679, 0.04134995, 0.0029913357, 0.050757334, 0.031863946, 0.043180045, - 0.053592253, -0.02633951, 0.04229047, 0.12401424, 0.1025523, 0.11914653, 0.10838079, - 0.119204566, 0.120582364, 0.079642124, 0.1136303, 0.103594445, 0.12434465, 0.10481718, - 0.10615024, 0.1161067, 0.101516, 0.11543929, 0.11498181, 0.1083647, 0.12498043, 0.117732316, - 0.080594465, 0.12140614, 0.10168964, 0.11630502, 0.097365364, 0.11659742, 0.11525785, - 0.095346555, 0.095523514, 0.1145297, 0.10820676, 0.113681756, 0.12088448, 0.11661095, - 0.09196416, 0.09367608, 0.12396194, 0.11715822, 0.10781161, 0.09206241, 0.11529953, 0.12193694, - 0.11471913, 0.1025523, 0.12246918, 0.12278436, 0.11647938, 0.09907566, 0.10939402, 0.11121245, - 0.09931412, -0.2015398, -0.19392101, -0.19934568, -0.19083071, -0.20022182, -0.18812077, - -0.19819336, -0.19751601, -0.18787658, -0.1910854, -0.19982933, -0.19259657, -0.1910668, - -0.19623408, -0.20643783, -0.17979786, -0.20085241, -0.20226628, -0.1943775, -0.19513902, - -0.1944603, -0.19675966, -0.20814213, -0.19372807, -0.18230462, -0.18796724, -0.19594413, - -0.19937015, -0.20221426, -0.1900377, -0.18905015, -0.20246184, -0.18973471, -0.1917036, - -0.1910854, -0.2045007, -0.20772256, -0.1910854, -0.19349803, -0.19836159, -0.20438254, - -0.16650572, -0.19694945, -0.19511227, -0.18056169, -0.19521528, -0.19218414, -0.19556037, - -0.1989097, -0.19989866, 0.110895164, 0.09209204, 0.13636513, 0.09708423, 0.12663901, - 0.11280878, 0.10437618, 0.008251642, 0.11656475, 0.062448665, 0.07663319, 0.076713376, - 0.09773914, 0.1284772, 0.0019391886, 0.08873351, 0.10645666, 0.06874694, 0.12830636, - 0.069761865, 0.12597786, 0.064558044, 0.14945637, 0.12600589, 0.08889626, 0.096229844, - 0.13689923, 0.15111938, 0.11476847, 0.012906413, 0.06886689, 0.05653629, 0.056540295, 0.1647724, - 0.1054803, 0.06795046, 0.12039944, 0.11954296, 0.052694272, 0.085520394, 0.110611565, - 0.11398453, 0.07550961, 0.023511963, 0.090924345, 0.0600122, 0.07526812, 0.088270955, - -0.03518031, 0.073293336, 0.17944553, 0.16982275, 0.1886539, 0.18693338, 0.18788463, 0.2058602, - 0.13861835, 0.20437749, 0.18895163, 0.16544276, 0.149991, 0.17463979, 0.17583887, 0.16696452, - 0.16749835, 0.1592365, 0.17954215, 0.1818188, 0.21207899, 0.15266286, 0.17395115, 0.15906107, - 0.21057771, 0.15467106, 0.17414747, 0.19151127, 0.14792846, 0.14762704, 0.1860418, 0.18808068, - 0.19654934, 0.17514904, 0.18510495, 0.16045007, 0.18320344, 0.18669076, 0.16069236, 0.17718756, - 0.14080223, 0.1681495, 0.17300002, 0.1528326, 0.16982275, 0.1817097, 0.16696694, 0.16177535, - 0.1604718, 0.16464049, 0.15210003, 0.16091338, 0.19544502, 0.1334315, 0.16168839, 0.11322618, - 0.19517533, 0.18929626, 0.17545204, 0.1665815, 0.09131178, 0.11004268, 0.20550796, 0.13831247, - 0.10610545, 0.12289211, 0.27147663, 0.20504008, 0.2518754, 0.20981932, 0.20138234, 0.19962592, - 0.15790789, 0.20949593, 0.23528637, 0.18096939, 0.08758456, 0.10911943, 0.18139273, 0.18525626, - 0.19391479, 0.11438076, 0.1093913, 0.22006766, 0.18334126, 0.21811387, 0.11004268, 0.19371085, - 0.23279056, 0.11004268, 0.11990581, 0.17242423, 0.21975593, 0.046734467, 0.1444371, 0.20759591, - 0.13962208, 0.14867997, 0.17288592, 0.14028637, 0.19978605, 0.1737019, -0.038705423, - -0.03880039, -0.060744748, 0.005578369, -0.026154364, -0.09166601, -0.061155446, 0.008943805, - -0.04777039, -0.012912485, -0.010861377, -0.01913654, -0.0061141956, -0.09119834, 0.034481876, - -0.008210908, -0.09062711, -0.0464008, -0.0038113478, -0.006515413, -0.06737334, 0.022068182, - -0.078238964, -0.10467487, -0.012385059, -0.008899481, -0.0507185, -0.0612416, -0.05302817, - 0.03657996, 0.0040081483, 0.0017336496, 0.00966107, -0.13457696, -0.106228024, -0.05810899, - -0.042826205, -0.004804179, -0.054947495, -0.0023088162, -0.083174944, -0.0812491, 0.0012216767, - 0.017188948, -0.0416347, -0.0750825, -0.052436177, -0.028371494, 0.07799446, -0.02655019, - -0.04801802, -0.11302035, -0.114139326, -0.17401277, -0.11443192, -0.19375448, -0.08697115, - -0.22462566, -0.18594599, 0.029962104, -0.03072077, -0.10795037, -0.0687454, -0.08853653, - -0.02800453, -0.0044006817, -0.14119355, -0.057319514, -0.23839943, -0.09940908, -0.03132951, - -0.07696326, -0.23962279, -0.05578459, -0.073864885, -0.16175121, -0.046830498, -0.071334355, - -0.12525235, -0.1762308, -0.17853433, -0.05481769, -0.10788009, -0.12848935, -0.21946594, - -0.07054761, -0.0043790466, -0.1421547, -0.062456187, -0.038439218, -0.01970637, 0.04187341, - -0.11302035, -0.06571084, 0.012916437, 0.008474918, -0.058553338, -0.05822342, -0.0072570713, - -0.117029555}, new int[] {150, 3}, 'c'); + -0.101017155, -0.11571029, -0.116997495, -0.123033985, -0.1222254, -0.11120513, -0.11710341, + -0.12319958, -0.124424405, -0.105285235, -0.08768927, -0.10296882, -0.11346505, -0.10607526, + -0.10681274, -0.11604863, -0.1070115, -0.114202365, -0.11168295, -0.11615404, -0.120522454, + -0.11282451, -0.11514864, -0.11681116, -0.11987897, -0.12054029, -0.112625614, -0.10337835, + -0.098809384, -0.1222254, -0.11966098, -0.11500366, -0.1222254, -0.122691356, -0.1168594, + -0.11369472, -0.11666928, -0.12075868, -0.10658686, -0.10251844, -0.119958505, -0.10873747, + -0.12036781, -0.11125211, -0.118474, 0.07354958, 0.06268418, 0.08751996, 0.05259535, 0.07969022, + 0.062334962, 0.07089297, -0.006484107, 0.0702586, 0.03601057, 0.03228142, 0.051330067, + 0.048092633, 0.0753836, 0.0026741663, 0.060346458, 0.064265735, 0.03208362, 0.07322607, + 0.034286126, 0.08459597, 0.040570714, 0.08494339, 0.06835921, 0.055334114, 0.06346921, + 0.08284429, 0.09769646, 0.07128828, 0.0012985547, 0.033257447, 0.024084045, 0.03130147, + 0.09381818, 0.062283173, 0.049273495, 0.0789609, 0.06648661, 0.030163772, 0.047266945, + 0.05704684, 0.06862679, 0.04134995, 0.0029913357, 0.050757334, 0.031863946, 0.043180045, + 0.053592253, -0.02633951, 0.04229047, 0.12401424, 0.1025523, 0.11914653, 0.10838079, + 0.119204566, 0.120582364, 0.079642124, 0.1136303, 0.103594445, 0.12434465, 0.10481718, + 0.10615024, 0.1161067, 0.101516, 0.11543929, 0.11498181, 0.1083647, 0.12498043, 0.117732316, + 0.080594465, 0.12140614, 0.10168964, 0.11630502, 0.097365364, 0.11659742, 0.11525785, + 0.095346555, 0.095523514, 0.1145297, 0.10820676, 0.113681756, 0.12088448, 0.11661095, + 0.09196416, 0.09367608, 0.12396194, 0.11715822, 0.10781161, 0.09206241, 0.11529953, 0.12193694, + 0.11471913, 0.1025523, 0.12246918, 0.12278436, 0.11647938, 0.09907566, 0.10939402, 0.11121245, + 0.09931412, -0.2015398, -0.19392101, -0.19934568, -0.19083071, -0.20022182, -0.18812077, + -0.19819336, -0.19751601, -0.18787658, -0.1910854, -0.19982933, -0.19259657, -0.1910668, + -0.19623408, -0.20643783, -0.17979786, -0.20085241, -0.20226628, -0.1943775, -0.19513902, + -0.1944603, -0.19675966, -0.20814213, -0.19372807, -0.18230462, -0.18796724, -0.19594413, + -0.19937015, -0.20221426, -0.1900377, -0.18905015, -0.20246184, -0.18973471, -0.1917036, + -0.1910854, -0.2045007, -0.20772256, -0.1910854, -0.19349803, -0.19836159, -0.20438254, + -0.16650572, -0.19694945, -0.19511227, -0.18056169, -0.19521528, -0.19218414, -0.19556037, + -0.1989097, -0.19989866, 0.110895164, 0.09209204, 0.13636513, 0.09708423, 0.12663901, + 0.11280878, 0.10437618, 0.008251642, 0.11656475, 0.062448665, 0.07663319, 0.076713376, + 0.09773914, 0.1284772, 0.0019391886, 0.08873351, 0.10645666, 0.06874694, 0.12830636, + 0.069761865, 0.12597786, 0.064558044, 0.14945637, 0.12600589, 0.08889626, 0.096229844, + 0.13689923, 0.15111938, 0.11476847, 0.012906413, 0.06886689, 0.05653629, 0.056540295, 0.1647724, + 0.1054803, 0.06795046, 0.12039944, 0.11954296, 0.052694272, 0.085520394, 0.110611565, + 0.11398453, 0.07550961, 0.023511963, 0.090924345, 0.0600122, 0.07526812, 0.088270955, + -0.03518031, 0.073293336, 0.17944553, 0.16982275, 0.1886539, 0.18693338, 0.18788463, 0.2058602, + 0.13861835, 0.20437749, 0.18895163, 0.16544276, 0.149991, 0.17463979, 0.17583887, 0.16696452, + 0.16749835, 0.1592365, 0.17954215, 0.1818188, 0.21207899, 0.15266286, 0.17395115, 0.15906107, + 0.21057771, 0.15467106, 0.17414747, 0.19151127, 0.14792846, 0.14762704, 0.1860418, 0.18808068, + 0.19654934, 0.17514904, 0.18510495, 0.16045007, 0.18320344, 0.18669076, 0.16069236, 0.17718756, + 0.14080223, 0.1681495, 0.17300002, 0.1528326, 0.16982275, 0.1817097, 0.16696694, 0.16177535, + 0.1604718, 0.16464049, 0.15210003, 0.16091338, 0.19544502, 0.1334315, 0.16168839, 0.11322618, + 0.19517533, 0.18929626, 0.17545204, 0.1665815, 0.09131178, 0.11004268, 0.20550796, 0.13831247, + 0.10610545, 0.12289211, 0.27147663, 0.20504008, 0.2518754, 0.20981932, 0.20138234, 0.19962592, + 0.15790789, 0.20949593, 0.23528637, 0.18096939, 0.08758456, 0.10911943, 0.18139273, 0.18525626, + 0.19391479, 0.11438076, 0.1093913, 0.22006766, 0.18334126, 0.21811387, 0.11004268, 0.19371085, + 0.23279056, 0.11004268, 0.11990581, 0.17242423, 0.21975593, 0.046734467, 0.1444371, 0.20759591, + 0.13962208, 0.14867997, 0.17288592, 0.14028637, 0.19978605, 0.1737019, -0.038705423, + -0.03880039, -0.060744748, 0.005578369, -0.026154364, -0.09166601, -0.061155446, 0.008943805, + -0.04777039, -0.012912485, -0.010861377, -0.01913654, -0.0061141956, -0.09119834, 0.034481876, + -0.008210908, -0.09062711, -0.0464008, -0.0038113478, -0.006515413, -0.06737334, 0.022068182, + -0.078238964, -0.10467487, -0.012385059, -0.008899481, -0.0507185, -0.0612416, -0.05302817, + 0.03657996, 0.0040081483, 0.0017336496, 0.00966107, -0.13457696, -0.106228024, -0.05810899, + -0.042826205, -0.004804179, -0.054947495, -0.0023088162, -0.083174944, -0.0812491, 0.0012216767, + 0.017188948, -0.0416347, -0.0750825, -0.052436177, -0.028371494, 0.07799446, -0.02655019, + -0.04801802, -0.11302035, -0.114139326, -0.17401277, -0.11443192, -0.19375448, -0.08697115, + -0.22462566, -0.18594599, 0.029962104, -0.03072077, -0.10795037, -0.0687454, -0.08853653, + -0.02800453, -0.0044006817, -0.14119355, -0.057319514, -0.23839943, -0.09940908, -0.03132951, + -0.07696326, -0.23962279, -0.05578459, -0.073864885, -0.16175121, -0.046830498, -0.071334355, + -0.12525235, -0.1762308, -0.17853433, -0.05481769, -0.10788009, -0.12848935, -0.21946594, + -0.07054761, -0.0043790466, -0.1421547, -0.062456187, -0.038439218, -0.01970637, 0.04187341, + -0.11302035, -0.06571084, 0.012916437, 0.008474918, -0.058553338, -0.05822342, -0.0072570713, + -0.117029555}, new int[] {150, 3}, 'c'); INDArray assertion = Nd4j.create(new double[] {-1.0949919, -1.1009998, -1.0998554, -1.1079034, -1.1003298, - -1.0877079, -1.0957471, -1.0970343, -1.1030709, -1.1040032, -1.0929829, -1.0988811, -1.1042137, - -1.1054386, -1.0862994, -1.0849832, -1.1002628, -1.110759, -1.0950522, -1.0957897, -1.1050256, - -1.0946627, -1.1018535, -1.0993341, -1.098271, -1.1026394, -1.0949415, -1.0964833, -1.0981458, - -1.1012137, -1.1069958, -1.0990812, -1.0898339, -1.0839114, -1.1073275, -1.104763, -1.0936487, - -1.1008704, -1.1013364, -1.0997316, -1.0965669, -1.0995414, -1.1094468, -1.0952749, -1.0912066, - -1.1022308, -1.0910097, -1.10264, -1.1618325, -1.1690543, -0.97703075, -1.1036359, -1.0788001, - -1.1137247, -1.0899199, -1.1072751, -1.0987172, -1.13885, -1.0621073, -1.0963553, -1.1102668, - -1.0912181, -1.0944556, -1.0698514, -1.1425608, -1.0848886, -1.0910273, -1.1232094, -1.0820669, - -1.1177288, -1.0674189, -1.1114442, -1.083288, -1.0998721, -1.1128973, -1.1165779, -1.0972028, - -1.0823506, -1.063015, -1.1330047, -1.1010458, -1.1247563, -1.1175389, -1.0550222, -1.0999088, - -1.1129185, -1.0832311, -1.0802083, -1.1165311, -1.0994279, -1.0973024, -1.0857224, -1.1129993, - -1.124351, -1.076585, -1.0954784, -1.0795343, -1.0691221, -1.1490538, -1.1465356, -1.0648118, - -1.0862738, -1.0950559, -1.1058216, -1.0949979, -1.0828075, -1.1237478, -1.0897596, -1.1059818, - -1.0852317, -1.1047591, -1.100405, -1.0904485, -1.1050392, -1.0961069, -1.0965644, -1.1031815, - -1.0815891, -1.0888373, -1.125975, -1.0903746, -1.1100911, -1.0954757, -1.1110255, -1.0917934, - -1.093133, -1.1051062, -1.1049292, -1.0859231, -1.1046766, -1.0992017, -1.0919989, -1.082815, - -1.1074618, -1.10575, -1.0909829, -1.0977867, -1.1071333, -1.116398, -1.0931609, -1.0865234, - -1.0971736, -1.1093404, -1.0894235, -1.0886579, -1.0949628, -1.1123666, -1.095872, -1.0940536, - -1.1059519, -1.1018884, -1.0942696, -1.0996943, -1.0963987, -1.1057898, -1.0936887, -1.102288, - -1.1016107, -1.0919713, -1.0952013, -1.1039451, -1.0967125, -1.0917866, -1.0969539, -1.1071577, - -1.0841576, -1.1052121, -1.106626, -1.098331, -1.0990925, -1.0984138, -1.095848, -1.1072304, - -1.0928164, -1.0921938, -1.0978565, -1.1058333, -1.1007886, -1.1036327, -1.0914562, -1.0939325, - -1.1073442, -1.0946171, -1.0945718, -1.0939536, -1.107369, -1.1089264, -1.0922892, -1.0947019, - -1.1073625, -1.1133835, -1.0755067, -1.1047142, -1.102877, -1.0883265, -1.0995088, -1.0964776, - -1.0998539, -1.2125868, -1.2135757, -0.9027819, -1.115231, -1.0709579, -1.1102388, -1.0866234, - -1.1004536, -1.1088862, -1.1537597, -1.0454466, -1.0995628, -1.1057239, -1.1056436, -1.0846179, - -1.0445701, -1.1711081, -1.0843138, -1.0936275, -1.1313372, -1.0717777, -1.1160054, -1.0597894, - -1.1212093, -1.0709189, -1.0943694, -1.131479, -1.1307347, -1.0900652, -1.0758451, -1.0502236, - -1.1520857, -1.0961251, -1.1360092, -1.1360053, -1.0277731, -1.091318, -1.1288478, -1.0763988, - -1.065361, -1.1322097, -1.0993836, -1.0881867, -1.0848137, -1.1232886, -1.133629, -1.0662166, - -1.0971287, -1.0676445, -1.0546416, -1.1780928, -1.1673087, -1.0611565, -1.0707793, -1.0977826, - -1.0995032, -1.0985519, -1.0761919, -1.1434338, -1.0776746, -1.0779177, -1.1014266, -1.1168783, - -1.0964613, -1.0952622, -1.1041365, -1.0999078, -1.1081696, -1.0878639, -1.0992746, -1.0690144, - -1.1284306, -1.1060928, -1.1209829, -1.0694662, -1.1174977, -1.0980213, -1.0806575, -1.1113796, - -1.111681, -1.0732663, -1.0971633, -1.0886947, -1.110095, -1.0898226, -1.1144775, -1.0917242, - -1.0868361, -1.1128345, -1.0963393, -1.1185608, -1.0912135, -1.086363, -1.1139716, -1.0969814, - -1.0850945, -1.0947206, -1.0999122, -1.1012157, -1.0932035, -1.105744, -1.0969306, -1.0670104, - -1.1290239, -1.100767, -1.1519758, -1.0700266, -1.0759057, -1.0683149, -1.0771854, -1.1524552, - -1.1406635, -1.0451982, -1.1123937, -1.1621376, -1.1453509, -0.99676645, -1.1160396, -1.0692043, - -1.1112604, -1.0837362, -1.0854926, -1.1272106, -1.0979462, -1.0721557, -1.1264727, -1.1378707, - -1.1163357, -1.0440625, -1.0785028, -1.0698442, -1.1493783, -1.1612072, -1.0505308, -1.0872571, - -1.0555155, -1.1635867, -1.0799185, -1.0216377, -1.1443856, -1.1345224, -1.0751246, -1.0277929, - -1.2008144, -1.1185431, -1.0553844, -1.1233582, -1.1039788, -1.0797728, -1.1123724, -1.0159799, - -1.0420641, -1.2544713, -1.1064723, -1.1284167, -1.0620935, -1.0654664, -1.1309781, -1.1004674, - -1.0726943, -1.1294085, -1.0945506, -1.0974507, -1.1057259, -1.0927036, -1.1695204, -1.0438402, - -1.086533, -1.1429209, -1.0986946, -1.0561051, -1.0885462, -1.149404, -1.0599625, -1.112509, - -1.1389449, -1.046655, -1.0674819, -1.1093009, -1.119824, -1.1481767, -1.0585686, -1.0911404, - -1.0579745, -1.050047, -1.194285, -1.136149, -1.08803, -1.0727472, -1.0830219, -1.1331651, - -1.0805265, -1.1281672, -1.1262413, -1.0437706, -1.0489775, -1.1078012, -1.141249, -1.1517346, - -1.1276698, -1.0213039, -1.0633042, -1.084772, -1.1497743, -1.0789506, -1.1388241, -1.0792432, - -1.125674, -1.0188907, -1.1565453, -1.2263924, -1.0104843, -1.0711672, -1.1182799, -1.079075, - -1.0988661, -1.0705098, -1.046906, -1.1836989, -1.0271709, -1.2082508, -1.0692605, -1.017894, - -1.0635278, -1.2261873, -1.0583237, -1.0764041, -1.1642903, -1.0648377, -1.0893415, -1.1432595, - -1.140007, -1.1423105, -1.0185939, -1.0557104, -1.0763197, -1.1672963, -1.09838, -1.0322114, - -1.1699871, -1.1210208, -1.0970039, -1.078271, -1.0132385, -1.1681323, -1.1208228, -1.0738388, - -1.0782803, -1.1453086, -1.0970035, -1.0460371, -1.1558095}, new int[] {150, 3}, 'c'); + -1.0877079, -1.0957471, -1.0970343, -1.1030709, -1.1040032, -1.0929829, -1.0988811, -1.1042137, + -1.1054386, -1.0862994, -1.0849832, -1.1002628, -1.110759, -1.0950522, -1.0957897, -1.1050256, + -1.0946627, -1.1018535, -1.0993341, -1.098271, -1.1026394, -1.0949415, -1.0964833, -1.0981458, + -1.1012137, -1.1069958, -1.0990812, -1.0898339, -1.0839114, -1.1073275, -1.104763, -1.0936487, + -1.1008704, -1.1013364, -1.0997316, -1.0965669, -1.0995414, -1.1094468, -1.0952749, -1.0912066, + -1.1022308, -1.0910097, -1.10264, -1.1618325, -1.1690543, -0.97703075, -1.1036359, -1.0788001, + -1.1137247, -1.0899199, -1.1072751, -1.0987172, -1.13885, -1.0621073, -1.0963553, -1.1102668, + -1.0912181, -1.0944556, -1.0698514, -1.1425608, -1.0848886, -1.0910273, -1.1232094, -1.0820669, + -1.1177288, -1.0674189, -1.1114442, -1.083288, -1.0998721, -1.1128973, -1.1165779, -1.0972028, + -1.0823506, -1.063015, -1.1330047, -1.1010458, -1.1247563, -1.1175389, -1.0550222, -1.0999088, + -1.1129185, -1.0832311, -1.0802083, -1.1165311, -1.0994279, -1.0973024, -1.0857224, -1.1129993, + -1.124351, -1.076585, -1.0954784, -1.0795343, -1.0691221, -1.1490538, -1.1465356, -1.0648118, + -1.0862738, -1.0950559, -1.1058216, -1.0949979, -1.0828075, -1.1237478, -1.0897596, -1.1059818, + -1.0852317, -1.1047591, -1.100405, -1.0904485, -1.1050392, -1.0961069, -1.0965644, -1.1031815, + -1.0815891, -1.0888373, -1.125975, -1.0903746, -1.1100911, -1.0954757, -1.1110255, -1.0917934, + -1.093133, -1.1051062, -1.1049292, -1.0859231, -1.1046766, -1.0992017, -1.0919989, -1.082815, + -1.1074618, -1.10575, -1.0909829, -1.0977867, -1.1071333, -1.116398, -1.0931609, -1.0865234, + -1.0971736, -1.1093404, -1.0894235, -1.0886579, -1.0949628, -1.1123666, -1.095872, -1.0940536, + -1.1059519, -1.1018884, -1.0942696, -1.0996943, -1.0963987, -1.1057898, -1.0936887, -1.102288, + -1.1016107, -1.0919713, -1.0952013, -1.1039451, -1.0967125, -1.0917866, -1.0969539, -1.1071577, + -1.0841576, -1.1052121, -1.106626, -1.098331, -1.0990925, -1.0984138, -1.095848, -1.1072304, + -1.0928164, -1.0921938, -1.0978565, -1.1058333, -1.1007886, -1.1036327, -1.0914562, -1.0939325, + -1.1073442, -1.0946171, -1.0945718, -1.0939536, -1.107369, -1.1089264, -1.0922892, -1.0947019, + -1.1073625, -1.1133835, -1.0755067, -1.1047142, -1.102877, -1.0883265, -1.0995088, -1.0964776, + -1.0998539, -1.2125868, -1.2135757, -0.9027819, -1.115231, -1.0709579, -1.1102388, -1.0866234, + -1.1004536, -1.1088862, -1.1537597, -1.0454466, -1.0995628, -1.1057239, -1.1056436, -1.0846179, + -1.0445701, -1.1711081, -1.0843138, -1.0936275, -1.1313372, -1.0717777, -1.1160054, -1.0597894, + -1.1212093, -1.0709189, -1.0943694, -1.131479, -1.1307347, -1.0900652, -1.0758451, -1.0502236, + -1.1520857, -1.0961251, -1.1360092, -1.1360053, -1.0277731, -1.091318, -1.1288478, -1.0763988, + -1.065361, -1.1322097, -1.0993836, -1.0881867, -1.0848137, -1.1232886, -1.133629, -1.0662166, + -1.0971287, -1.0676445, -1.0546416, -1.1780928, -1.1673087, -1.0611565, -1.0707793, -1.0977826, + -1.0995032, -1.0985519, -1.0761919, -1.1434338, -1.0776746, -1.0779177, -1.1014266, -1.1168783, + -1.0964613, -1.0952622, -1.1041365, -1.0999078, -1.1081696, -1.0878639, -1.0992746, -1.0690144, + -1.1284306, -1.1060928, -1.1209829, -1.0694662, -1.1174977, -1.0980213, -1.0806575, -1.1113796, + -1.111681, -1.0732663, -1.0971633, -1.0886947, -1.110095, -1.0898226, -1.1144775, -1.0917242, + -1.0868361, -1.1128345, -1.0963393, -1.1185608, -1.0912135, -1.086363, -1.1139716, -1.0969814, + -1.0850945, -1.0947206, -1.0999122, -1.1012157, -1.0932035, -1.105744, -1.0969306, -1.0670104, + -1.1290239, -1.100767, -1.1519758, -1.0700266, -1.0759057, -1.0683149, -1.0771854, -1.1524552, + -1.1406635, -1.0451982, -1.1123937, -1.1621376, -1.1453509, -0.99676645, -1.1160396, -1.0692043, + -1.1112604, -1.0837362, -1.0854926, -1.1272106, -1.0979462, -1.0721557, -1.1264727, -1.1378707, + -1.1163357, -1.0440625, -1.0785028, -1.0698442, -1.1493783, -1.1612072, -1.0505308, -1.0872571, + -1.0555155, -1.1635867, -1.0799185, -1.0216377, -1.1443856, -1.1345224, -1.0751246, -1.0277929, + -1.2008144, -1.1185431, -1.0553844, -1.1233582, -1.1039788, -1.0797728, -1.1123724, -1.0159799, + -1.0420641, -1.2544713, -1.1064723, -1.1284167, -1.0620935, -1.0654664, -1.1309781, -1.1004674, + -1.0726943, -1.1294085, -1.0945506, -1.0974507, -1.1057259, -1.0927036, -1.1695204, -1.0438402, + -1.086533, -1.1429209, -1.0986946, -1.0561051, -1.0885462, -1.149404, -1.0599625, -1.112509, + -1.1389449, -1.046655, -1.0674819, -1.1093009, -1.119824, -1.1481767, -1.0585686, -1.0911404, + -1.0579745, -1.050047, -1.194285, -1.136149, -1.08803, -1.0727472, -1.0830219, -1.1331651, + -1.0805265, -1.1281672, -1.1262413, -1.0437706, -1.0489775, -1.1078012, -1.141249, -1.1517346, + -1.1276698, -1.0213039, -1.0633042, -1.084772, -1.1497743, -1.0789506, -1.1388241, -1.0792432, + -1.125674, -1.0188907, -1.1565453, -1.2263924, -1.0104843, -1.0711672, -1.1182799, -1.079075, + -1.0988661, -1.0705098, -1.046906, -1.1836989, -1.0271709, -1.2082508, -1.0692605, -1.017894, + -1.0635278, -1.2261873, -1.0583237, -1.0764041, -1.1642903, -1.0648377, -1.0893415, -1.1432595, + -1.140007, -1.1423105, -1.0185939, -1.0557104, -1.0763197, -1.1672963, -1.09838, -1.0322114, + -1.1699871, -1.1210208, -1.0970039, -1.078271, -1.0132385, -1.1681323, -1.1208228, -1.0738388, + -1.0782803, -1.1453086, -1.0970035, -1.0460371, -1.1558095}, new int[] {150, 3}, 'c'); Nd4j.getExecutioner().exec(new LogSoftMax(test)); assertEquals(assertion, test); @@ -781,20 +868,24 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testSoftmax() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSoftmax(Nd4jBackend backend) { INDArray vec = Nd4j.linspace(1, 18, 18, DataType.DOUBLE); INDArray matrix = vec.dup().reshape(3, 6); Nd4j.getExecutioner().exec((CustomOp) new SoftMax(matrix)); INDArray assertion = Nd4j.create( - new double[] {0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913, - 0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913, - 0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913}, - new int[] {3, 6}, 'c'); + new double[] {0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913, + 0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913, + 0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913}, + new int[] {3, 6}, 'c'); assertEquals(assertion, matrix); } @Test - public void testStdev() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStdev(Nd4jBackend backend) { INDArray arr = Nd4j.create(new float[] {0.9296161f, 0.31637555f, 0.1839188f}, new int[] {1, 3}, ordering()); double stdev = arr.stdNumber().doubleValue(); double stdev2 = arr.std(1).getDouble(0); @@ -805,7 +896,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testVariance() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariance(Nd4jBackend backend) { INDArray arr = Nd4j.create(new float[] {0.9296161f, 0.31637555f, 0.1839188f}, new int[] {1, 3}, ordering()); double var = arr.varNumber().doubleValue(); @@ -818,7 +911,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testEpsOps() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEpsOps(Nd4jBackend backend) { INDArray ones = Nd4j.ones(DataType.DOUBLE, 1, 6); double tiny = 1.000000000000001; assertTrue(ones.eps(tiny).all()); @@ -829,7 +924,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testVarianceSingleVsMultipleDimensions() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVarianceSingleVsMultipleDimensions(Nd4jBackend backend) { // this test should always run in double DataType type = Nd4j.dataType(); DataTypeUtil.setDTypeForContext(DataType.DOUBLE); @@ -874,7 +971,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testHistogram1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHistogram1(Nd4jBackend backend) { INDArray x = Nd4j.linspace(1, 1000, 100000, DataType.DOUBLE); INDArray z = Nd4j.zeros(DataType.LONG,new long[]{20}); @@ -896,7 +995,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testHistogram2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHistogram2(Nd4jBackend backend) { INDArray x = Nd4j.create(new float[] {0f, 0f, 0f, 5f, 5f, 5f, 10f, 10f, 10f}); INDArray xDup = x.dup(); @@ -915,7 +1016,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testEuclideanManhattanDistanceAlongDimension_Rank4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEuclideanManhattanDistanceAlongDimension_Rank4(Nd4jBackend backend) { DataType initialType = Nd4j.dataType(); DataTypeUtil.setDTypeForContext(DataType.DOUBLE); Nd4j.getRandom().setSeed(12345); @@ -953,9 +1056,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { INDArray out = Nd4j.getExecutioner().exec(new EuclideanDistance(first, second, 1, 2, 3)); Pair firstTadInfo = - Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(first, 1, 2, 3); + Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(first, 1, 2, 3); Pair secondTadInfo = - Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(second, 1, 2, 3); + Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(second, 1, 2, 3); INDArray outManhattan = Nd4j.getExecutioner().exec(new ManhattanDistance(first, second, 1, 2, 3)); @@ -979,7 +1082,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testPile1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPile1(Nd4jBackend backend) { List arrays = new ArrayList<>(); for (int i = 0; i < 10; i++) { arrays.add(Nd4j.create(10, 10).assign(i)); @@ -994,7 +1099,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testPile2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPile2(Nd4jBackend backend) { List arrays = new ArrayList<>(); for (int i = 0; i < 10; i++) { arrays.add(Nd4j.create(10, 10, 10).assign(i)); @@ -1009,7 +1116,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testMean1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMean1(Nd4jBackend backend) { INDArray array = Nd4j.create(32, 100, 100).assign(-119f); for (int i = 0; i < 32; i++) { val tad = array.tensorAlongDimension(i, 1, 2); @@ -1025,7 +1134,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { } @Test - public void testMean2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMean2(Nd4jBackend backend) { INDArray array = Nd4j.create(32, 100, 100); for (int i = 0; i < 32; i++) { array.tensorAlongDimension(i, 1, 2).assign((float) 100 + i); @@ -1039,14 +1150,18 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testNorm2_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm2_1(Nd4jBackend backend) { INDArray array = Nd4j.rand(1769472, 9); INDArray max = array.max(1); } @Test - public void testNorm2_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNorm2_2(Nd4jBackend backend) { INDArray array = Nd4j.rand(new int[]{127, 164}, 1, 100, Nd4j.getRandom()); double norm2 = array.norm2Number().doubleValue(); @@ -1060,7 +1175,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { */ @Test @Disabled - public void testTadEws() { + public void testTadEws(Nd4jBackend backend) { INDArray array = Nd4j.create(32, 5, 10); assertEquals(1, array.tensorAlongDimension(0, 1, 2).elementWiseStride()); } @@ -1068,7 +1183,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test - public void testTear1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTear1(Nd4jBackend backend) { List arrays = new ArrayList<>(); val num = 10; for (int i = 0; i < num; i++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/RationalTanhTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/RationalTanhTest.java index 045e3e78c..83221b2e6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/RationalTanhTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/RationalTanhTest.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.ops; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative; import org.nd4j.linalg.factory.Nd4j; @@ -31,15 +32,13 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertTrue; -@RunWith(Parameterized.class) -public class RationalTanhTest extends BaseNd4jTest { - public RationalTanhTest(Nd4jBackend backend) { - super(backend); - } +public class RationalTanhTest extends BaseNd4jTestWithBackends { @Test - public void gradientCheck() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void gradientCheck(Nd4jBackend backend) { double eps = 1e-6; INDArray A = Nd4j.linspace(-3, 3, 10).reshape(2, 5); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/broadcast/row/RowVectorOpsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/broadcast/row/RowVectorOpsC.java index 8971e80b2..3ba7d0841 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/broadcast/row/RowVectorOpsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/broadcast/row/RowVectorOpsC.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.ops.broadcast.row; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -34,16 +35,15 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class RowVectorOpsC extends BaseNd4jTest { - public RowVectorOpsC(Nd4jBackend backend) { - super(backend); - } +public class RowVectorOpsC extends BaseNd4jTestWithBackends { + @Test - public void testAddi() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAddi(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); arr.addiRowVector(Nd4j.create(new double[] {1, 2})); INDArray assertion = Nd4j.create(new double[][] {{2, 4}, {4, 6}}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/copy/CopyTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/copy/CopyTest.java index 21f70bc67..fdbcc7770 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/copy/CopyTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/copy/CopyTest.java @@ -21,31 +21,33 @@ package org.nd4j.linalg.ops.copy; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) -public class CopyTest extends BaseNd4jTest { - public CopyTest(Nd4jBackend backend) { - super(backend); - } + +public class CopyTest extends BaseNd4jTestWithBackends { @Test - public void testCopy() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCopy(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray dup = arr.dup(); assertEquals(arr, dup); } @Test - public void testDup() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDup(Nd4jBackend backend) { for (int x = 0; x < 100; x++) { INDArray orig = Nd4j.linspace(1, 4, 4); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/options/ArrayOptionsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/options/ArrayOptionsTests.java index a27a92c76..3c40a338a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/options/ArrayOptionsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/options/ArrayOptionsTests.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.options; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.api.shape.options.ArrayType; @@ -35,13 +36,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; @Slf4j -@RunWith(Parameterized.class) -public class ArrayOptionsTests extends BaseNd4jTest { + +public class ArrayOptionsTests extends BaseNd4jTestWithBackends { private static long[] shapeInfo; - public ArrayOptionsTests(Nd4jBackend backend) { - super(backend); - } @BeforeEach @@ -50,33 +48,43 @@ public class ArrayOptionsTests extends BaseNd4jTest { } @Test - public void testArrayType_0() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArrayType_0(Nd4jBackend backend) { assertEquals(ArrayType.DENSE, ArrayOptionsHelper.arrayType(shapeInfo)); } @Test - public void testArrayType_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArrayType_1(Nd4jBackend backend) { ArrayOptionsHelper.setOptionBit(shapeInfo, ArrayType.EMPTY); assertEquals(ArrayType.EMPTY, ArrayOptionsHelper.arrayType(shapeInfo)); } @Test - public void testArrayType_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArrayType_2(Nd4jBackend backend) { ArrayOptionsHelper.setOptionBit(shapeInfo, ArrayType.SPARSE); assertEquals(ArrayType.SPARSE, ArrayOptionsHelper.arrayType(shapeInfo)); } @Test - public void testArrayType_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArrayType_3(Nd4jBackend backend) { ArrayOptionsHelper.setOptionBit(shapeInfo, ArrayType.COMPRESSED); assertEquals(ArrayType.COMPRESSED, ArrayOptionsHelper.arrayType(shapeInfo)); } @Test - public void testDataTypesToFromLong(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDataTypesToFromLong(Nd4jBackend backend) { for(DataType dt : DataType.values()){ if(dt == DataType.UNKNOWN) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/InfNanTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/InfNanTests.java index 898232888..48b8944e3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/InfNanTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/InfNanTests.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.profiling; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil; @@ -35,12 +36,9 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertThrows; -@RunWith(Parameterized.class) -public class InfNanTests extends BaseNd4jTest { - public InfNanTests(Nd4jBackend backend) { - super(backend); - } +public class InfNanTests extends BaseNd4jTestWithBackends { + @BeforeEach public void setUp() { @@ -53,22 +51,26 @@ public class InfNanTests extends BaseNd4jTest { } @Test() - public void testInf1() { - assertThrows(ND4JIllegalStateException.class,() -> { - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.INF_PANIC); + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInf1(Nd4jBackend backend) { + assertThrows(ND4JIllegalStateException.class,() -> { + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.INF_PANIC); - INDArray x = Nd4j.create(100); + INDArray x = Nd4j.create(100); - x.putScalar(2, Float.NEGATIVE_INFINITY); + x.putScalar(2, Float.NEGATIVE_INFINITY); - OpExecutionerUtil.checkForAny(x); - }); + OpExecutionerUtil.checkForAny(x); + }); } @Test() - public void testInf2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInf2(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); @@ -82,7 +84,9 @@ public class InfNanTests extends BaseNd4jTest { } @Test - public void testInf3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInf3(Nd4jBackend backend) { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); INDArray x = Nd4j.create(100); @@ -91,7 +95,9 @@ public class InfNanTests extends BaseNd4jTest { } @Test - public void testInf4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInf4(Nd4jBackend backend) { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED); INDArray x = Nd4j.create(100); @@ -100,22 +106,26 @@ public class InfNanTests extends BaseNd4jTest { } @Test() - public void testNaN1() { - assertThrows(ND4JIllegalStateException.class,() -> { - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC); + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNaN1(Nd4jBackend backend) { + assertThrows(ND4JIllegalStateException.class,() -> { + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC); - INDArray x = Nd4j.create(100); + INDArray x = Nd4j.create(100); - x.putScalar(2, Float.NaN); + x.putScalar(2, Float.NaN); - OpExecutionerUtil.checkForAny(x); + OpExecutionerUtil.checkForAny(x); }); } @Test() - public void testNaN2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNaN2(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); @@ -129,7 +139,9 @@ public class InfNanTests extends BaseNd4jTest { } @Test - public void testNaN3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNaN3(Nd4jBackend backend) { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); INDArray x = Nd4j.create(100); @@ -138,7 +150,9 @@ public class InfNanTests extends BaseNd4jTest { } @Test - public void testNaN4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNaN4(Nd4jBackend backend) { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED); INDArray x = Nd4j.create(100); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java index 5312f7adf..b942b8430 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java @@ -27,7 +27,9 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -48,11 +50,8 @@ import org.nd4j.linalg.profiler.ProfilerConfig; import static org.junit.jupiter.api.Assertions.*; @Slf4j -public class OperationProfilerTests extends BaseNd4jTest { +public class OperationProfilerTests extends BaseNd4jTestWithBackends { - public OperationProfilerTests(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -71,7 +70,9 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - public void testCounter1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCounter1(Nd4jBackend backend) { INDArray array = Nd4j.createUninitialized(100); array.assign(10f); @@ -82,7 +83,9 @@ public class OperationProfilerTests extends BaseNd4jTest { @Test - public void testStack1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStack1(Nd4jBackend backend) { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ALL); @@ -99,7 +102,9 @@ public class OperationProfilerTests extends BaseNd4jTest { @Test - public void testBadCombos1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBadCombos1(Nd4jBackend backend) { INDArray x = Nd4j.create(100); INDArray y = Nd4j.create(100); @@ -110,7 +115,9 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - public void testBadCombos2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBadCombos2(Nd4jBackend backend) { INDArray x = Nd4j.create(100).reshape('f', 10, 10); INDArray y = Nd4j.create(100).reshape('c', 10, 10); @@ -121,7 +128,9 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - public void testBadCombos3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBadCombos3(Nd4jBackend backend) { INDArray x = Nd4j.create(27).reshape('c', 3, 3, 3).tensorAlongDimension(0, 1, 2); INDArray y = Nd4j.create(100).reshape('f', 10, 10); @@ -134,7 +143,9 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - public void testBadCombos4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBadCombos4(Nd4jBackend backend) { INDArray x = Nd4j.create(27).reshape('c', 3, 3, 3).tensorAlongDimension(0, 1, 2); INDArray y = Nd4j.create(100).reshape('f', 10, 10); INDArray z = Nd4j.create(100).reshape('f', 10, 10); @@ -148,7 +159,9 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - public void testBadCombos5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBadCombos5(Nd4jBackend backend) { INDArray w = Nd4j.create(100).reshape('c', 10, 10); INDArray x = Nd4j.create(100).reshape('c', 10, 10); INDArray y = Nd4j.create(100).reshape('f', 10, 10); @@ -163,7 +176,7 @@ public class OperationProfilerTests extends BaseNd4jTest { @Test @Disabled - public void testBadCombos6() { + public void testBadCombos6(Nd4jBackend backend) { INDArray x = Nd4j.create(27).reshape('f', 3, 3, 3).slice(1); INDArray y = Nd4j.create(100).reshape('f', 10, 10); @@ -175,11 +188,13 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - public void testBadTad1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBadTad1(Nd4jBackend backend) { INDArray x = Nd4j.create(2, 4, 5, 6); Pair pair = - Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, 0, 2); + Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, 0, 2); OpProfiler.PenaltyCause[] causes = OpProfiler.getInstance().processTADOperands(pair.getFirst()); @@ -189,11 +204,13 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - public void testBadTad2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBadTad2(Nd4jBackend backend) { INDArray x = Nd4j.create(2, 4, 5, 6); Pair pair = - Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, 2, 3); + Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, 2, 3); OpProfiler.PenaltyCause[] causes = OpProfiler.getInstance().processTADOperands(pair.getFirst()); @@ -205,11 +222,13 @@ public class OperationProfilerTests extends BaseNd4jTest { @Test - public void testBadTad3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBadTad3(Nd4jBackend backend) { INDArray x = Nd4j.create(new int[] {2, 4, 5, 6, 7}, 'f'); Pair pair = - Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, 0, 2, 4); + Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, 0, 2, 4); OpProfiler.PenaltyCause[] causes = OpProfiler.getInstance().processTADOperands(pair.getFirst()); @@ -220,7 +239,7 @@ public class OperationProfilerTests extends BaseNd4jTest { @Test @Disabled - public void testBadTad4() { + public void testBadTad4(Nd4jBackend backend) { INDArray x = Nd4j.create(2, 4, 5, 6); Pair pair = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, 3); @@ -234,7 +253,9 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - public void testBadTad5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBadTad5(Nd4jBackend backend) { INDArray x = Nd4j.create(new int[] {2, 4, 5, 6, 7}, 'f'); Pair pair = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, 4); @@ -249,7 +270,9 @@ public class OperationProfilerTests extends BaseNd4jTest { @Test - public void testCxFxF1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCxFxF1(Nd4jBackend backend) { INDArray a = Nd4j.create(10, 10).reshape('f', 10, 10); INDArray b = Nd4j.create(10, 10).reshape('c', 10, 10); INDArray c = Nd4j.create(10, 10).reshape('f', 10, 10); @@ -259,7 +282,9 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - public void testCxFxF2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCxFxF2(Nd4jBackend backend) { INDArray a = Nd4j.create(10, 10).reshape('c', 10, 10); INDArray b = Nd4j.create(10, 10).reshape('c', 10, 10); INDArray c = Nd4j.create(10, 10).reshape('f', 10, 10); @@ -269,7 +294,9 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - public void testCxFxF3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCxFxF3(Nd4jBackend backend) { INDArray a = Nd4j.create(10, 10).reshape('c', 10, 10); INDArray b = Nd4j.create(10, 10).reshape('c', 10, 10); INDArray c = Nd4j.create(10, 10).reshape('c', 10, 10); @@ -280,7 +307,9 @@ public class OperationProfilerTests extends BaseNd4jTest { @Test - public void testBlasFF() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBlasFF(Nd4jBackend backend) { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ALL); INDArray a = Nd4j.create(10, 10).reshape('f', 10, 10); @@ -293,7 +322,7 @@ public class OperationProfilerTests extends BaseNd4jTest { @Test() - public void testNaNPanic1() { + public void testNaNPanic1(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC); @@ -305,7 +334,7 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test() - public void testNaNPanic2() { + public void testNaNPanic2(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.INF_PANIC); @@ -317,7 +346,7 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test() - public void testNaNPanic3() { + public void testNaNPanic3(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); @@ -330,7 +359,7 @@ public class OperationProfilerTests extends BaseNd4jTest { @Test() - public void testScopePanic1() { + public void testScopePanic1(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); @@ -349,7 +378,7 @@ public class OperationProfilerTests extends BaseNd4jTest { @Test() - public void testScopePanic2() { + public void testScopePanic2(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); @@ -376,7 +405,9 @@ public class OperationProfilerTests extends BaseNd4jTest { @Test - public void testScopePanic3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScopePanic3(Nd4jBackend backend) { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); @@ -396,7 +427,9 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - public void testScopePanicPerf() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScopePanicPerf(Nd4jBackend backend) { try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS121")) { INDArray x = Nd4j.create(1000, 1000).assign(1.0); INDArray y = Nd4j.create(1000, 1000).assign(1.0); @@ -434,7 +467,9 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test - public void testExtendedStatistics() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExtendedStatistics(Nd4jBackend backend) { Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().nativeStatistics(true).build()); INDArray array = Nd4j.ones(10); @@ -449,6 +484,8 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testNanPanic(){ try { DynamicCustomOp op = DynamicCustomOp.builder("add") @@ -480,6 +517,8 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testInfPanic(){ try { DynamicCustomOp op = DynamicCustomOp.builder("add") @@ -511,6 +550,8 @@ public class OperationProfilerTests extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testOpProfilerOpContextLegacy(){ for(boolean nan : new boolean[]{true, false}) { @@ -534,6 +575,8 @@ public class OperationProfilerTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testOpProfilerOpContextCustomOp(){ for(boolean nan : new boolean[]{true, false}) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java index 9e7c68979..db17f7c3f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java @@ -26,9 +26,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.api.ops.performance.primitives.AveragingTransactionsHolder; @@ -40,26 +41,25 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j -@RunWith(Parameterized.class) -public class PerformanceTrackerTests extends BaseNd4jTest { - public PerformanceTrackerTests(Nd4jBackend backend) { - super(backend); - } + +public class PerformanceTrackerTests extends BaseNd4jTestWithBackends { @BeforeEach - public void setUp() { + public void setUp(Nd4jBackend backend) { PerformanceTracker.getInstance().clear(); Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.BANDWIDTH); } @AfterEach - public void tearDown() { + public void tearDown(Nd4jBackend backend) { PerformanceTracker.getInstance().clear(); Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); } @Test - public void testAveragedHolder_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAveragedHolder_1(Nd4jBackend backend) { val holder = new AveragingTransactionsHolder(); holder.addValue(MemcpyDirection.HOST_TO_HOST,50L); @@ -69,7 +69,9 @@ public class PerformanceTrackerTests extends BaseNd4jTest { } @Test - public void testAveragedHolder_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAveragedHolder_2(Nd4jBackend backend) { val holder = new AveragingTransactionsHolder(); holder.addValue(MemcpyDirection.HOST_TO_HOST, 50L); @@ -80,7 +82,9 @@ public class PerformanceTrackerTests extends BaseNd4jTest { } @Test - public void testPerformanceTracker_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPerformanceTracker_1(Nd4jBackend backend) { val perf = PerformanceTracker.getInstance(); // 100 nanoseconds spent for 5000 bytes. result should be around 50000 bytes per microsecond @@ -89,7 +93,9 @@ public class PerformanceTrackerTests extends BaseNd4jTest { } @Test - public void testPerformanceTracker_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPerformanceTracker_2(Nd4jBackend backend) { val perf = PerformanceTracker.getInstance(); // 10 nanoseconds spent for 5000 bytes. result should be around 500000 bytes per microsecond @@ -98,7 +104,9 @@ public class PerformanceTrackerTests extends BaseNd4jTest { } @Test - public void testPerformanceTracker_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPerformanceTracker_3(Nd4jBackend backend) { val perf = PerformanceTracker.getInstance(); // 10000 nanoseconds spent for 5000 bytes. result should be around 500 bytes per microsecond @@ -108,7 +116,7 @@ public class PerformanceTrackerTests extends BaseNd4jTest { @Test @Disabled - public void testTrackerCpu_1() { + public void testTrackerCpu_1(Nd4jBackend backend) { if (!Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("native")) return; @@ -126,7 +134,7 @@ public class PerformanceTrackerTests extends BaseNd4jTest { @Test @Disabled("useless these days") - public void testTrackerGpu_1() { + public void testTrackerGpu_1(Nd4jBackend backend) { if (!Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("cuda")) return; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/StackAggregatorTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/StackAggregatorTests.java index c0f5470a8..2cb5048fe 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/StackAggregatorTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/StackAggregatorTests.java @@ -25,7 +25,9 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.factory.Nd4j; @@ -39,11 +41,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j -public class StackAggregatorTests extends BaseNd4jTest { +public class StackAggregatorTests extends BaseNd4jTestWithBackends { - public StackAggregatorTests(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -51,20 +50,22 @@ public class StackAggregatorTests extends BaseNd4jTest { } @BeforeEach - public void setUp() { + public void setUp(Nd4jBackend backend) { Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().stackTrace(true).build()); Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ALL); OpProfiler.getInstance().reset(); } @AfterEach - public void tearDown() { + public void tearDown(Nd4jBackend backend) { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED); } @Test - public void testBasicBranching1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicBranching1(Nd4jBackend backend) { StackAggregator aggregator = new StackAggregator(); aggregator.incrementCount(); @@ -76,7 +77,9 @@ public class StackAggregatorTests extends BaseNd4jTest { } @Test - public void testBasicBranching2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicBranching2(Nd4jBackend backend) { StackAggregator aggregator = new StackAggregator(); for (int i = 0; i < 10; i++) { @@ -91,7 +94,9 @@ public class StackAggregatorTests extends BaseNd4jTest { @Test - public void testTrailingFrames1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTrailingFrames1(Nd4jBackend backend) { StackAggregator aggregator = new StackAggregator(); aggregator.incrementCount(); @@ -104,8 +109,10 @@ public class StackAggregatorTests extends BaseNd4jTest { assertTrue(descriptor.getStackTrace()[descriptor.size() - 1].getClassName().contains("StackAggregatorTests")); } - /*@Test - public void testTrailingFrames2() { + /* @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTrailingFrames2(Nd4jBackend backend) { INDArray x = Nd4j.create(new int[] {10, 10}, 'f'); INDArray y = Nd4j.create(new int[] {10, 10}, 'c'); @@ -130,7 +137,7 @@ public class StackAggregatorTests extends BaseNd4jTest { @Test @Disabled - public void testScalarAggregator() { + public void testScalarAggregator(Nd4jBackend backend) { INDArray x = Nd4j.create(10); x.putScalar(0, 1.0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/HalfTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/HalfTests.java index f198db33d..535cc32ac 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/HalfTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/HalfTests.java @@ -25,9 +25,10 @@ import lombok.val; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -36,15 +37,10 @@ import org.nd4j.linalg.ops.transforms.Transforms; import static junit.framework.TestCase.assertTrue; @Slf4j -@RunWith(Parameterized.class) -public class HalfTests extends BaseNd4jTest { - private DataType initialType; - - public HalfTests(Nd4jBackend backend) { - super(backend); - } +public class HalfTests extends BaseNd4jTestWithBackends { + private DataType initialType = Nd4j.dataType(); @BeforeEach public void setUp() { if (!Nd4j.getExecutioner().getClass().getSimpleName().toLowerCase().contains("cuda")) @@ -63,7 +59,9 @@ public class HalfTests extends BaseNd4jTest { } @Test - public void testRandomNorman_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRandomNorman_1(Nd4jBackend backend) { val array = Nd4j.randn(new long[]{20, 30}); val sum = Transforms.abs(array).sumNumber().doubleValue(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomPerformanceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomPerformanceTests.java index 6e5c966b8..3208ece6e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomPerformanceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomPerformanceTests.java @@ -22,22 +22,19 @@ package org.nd4j.linalg.rng; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Disabled; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; -import org.nd4j.linalg.factory.Nd4jBackend; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; @Slf4j -@RunWith(Parameterized.class) -@Disabled -public class RandomPerformanceTests extends BaseNd4jTest { - public RandomPerformanceTests(Nd4jBackend backend) { - super(backend); - } +@Disabled +public class RandomPerformanceTests extends BaseNd4jTestWithBackends { + /* - @Test + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDropoutPerformance() throws Exception { for (int i = 0; i < 100; i++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java index 34ed252ed..5cef24bfb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java @@ -28,9 +28,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.ndarray.INDArray; @@ -70,14 +71,11 @@ import java.util.concurrent.atomic.AtomicInteger; import static org.junit.jupiter.api.Assertions.*; @Slf4j -@RunWith(Parameterized.class) -public class RandomTests extends BaseNd4jTest { + +public class RandomTests extends BaseNd4jTestWithBackends { private DataType initialType; - public RandomTests(Nd4jBackend backend) { - super(backend); - } @BeforeEach public void setUp() { @@ -91,21 +89,25 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testCrossBackendEquality1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCrossBackendEquality1(Nd4jBackend backend) { int[] shape = {12}; double mean = 0; double standardDeviation = 1.0; INDArray exp = Nd4j.create(new double[] {-0.832718168582558, 1.3312306172061867, -0.27101354040045766, 1.0368130323476494, -0.6257379511224601, 0.30653534119847814, 0.28250229228899343, -0.5464191486048424, 0.5182898732953277, 1.463107608378911, 0.5634855878214299, -1.4979616922031507}); Nd4j.getRandom().setSeed(12345); INDArray arr = Nd4j.getExecutioner().exec(new GaussianDistribution( - Nd4j.createUninitialized(shape, Nd4j.order()), mean, standardDeviation), Nd4j.getRandom()); + Nd4j.createUninitialized(shape, Nd4j.order()), mean, standardDeviation), Nd4j.getRandom()); assertEquals(exp, arr); } @Test - public void testDistribution1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDistribution1(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -127,7 +129,9 @@ public class RandomTests extends BaseNd4jTest { @Test - public void testDistribution2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDistribution2(Nd4jBackend backend) { val random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); val random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -153,7 +157,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testDistribution3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDistribution3(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); INDArray z1 = Nd4j.create(128); @@ -167,7 +173,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testDistribution4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDistribution4(Nd4jBackend backend) { for (int i = 0; i < 100; i++) { Nd4j.getRandom().setSeed(119); @@ -182,7 +190,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testDistribution5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDistribution5(Nd4jBackend backend) { for (int i = 0; i < 100; i++) { Nd4j.getRandom().setSeed(120); @@ -197,7 +207,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testDistribution6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDistribution6(Nd4jBackend backend) { for (int i = 0; i < 100; i++) { Nd4j.getRandom().setSeed(120); @@ -212,7 +224,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testLinspace1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLinspace1(Nd4jBackend backend) { INDArray z1 = Nd4j.linspace(1, 100, 200, DataType.DOUBLE); Linspace linspace = new Linspace((double) 1, (double) 100, 200, DataType.DOUBLE); @@ -224,7 +238,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testDropoutInverted1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDropoutInverted1(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -249,7 +265,9 @@ public class RandomTests extends BaseNd4jTest { @Test - public void testDropout1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDropout1(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -269,7 +287,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testAlphaDropout1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAlphaDropout1(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -290,7 +310,9 @@ public class RandomTests extends BaseNd4jTest { @Test - public void testGaussianDistribution1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGaussianDistribution1(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -318,7 +340,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testGaussianDistribution2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGaussianDistribution2(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random3 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -357,7 +381,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testGaussianDistribution3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGaussianDistribution3(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -388,7 +414,9 @@ public class RandomTests extends BaseNd4jTest { * @throws Exception */ @Test - public void testAndersonDarling() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAndersonDarling(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); INDArray z1 = Nd4j.create(1000); @@ -425,7 +453,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testStepOver1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStepOver1(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); INDArray z0 = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(DataType.DOUBLE, 1000000), 0.0, 1.0)); @@ -449,14 +479,18 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testSum_119() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum_119(Nd4jBackend backend) { INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000); val sum = z2.sumNumber().doubleValue(); assertEquals(0.0, sum, 1e-5); } @Test - public void testLegacyDistribution1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLegacyDistribution1(Nd4jBackend backend) { NormalDistribution distribution = new NormalDistribution(new DefaultRandom(), 0.0, 1.0); INDArray z1 = distribution.sample(new int[] {1, 1000000}); @@ -465,7 +499,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testSetSeed1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSetSeed1(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -504,7 +540,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testJavaSide1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJavaSide1(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -522,7 +560,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testJavaSide2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJavaSide2(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -541,7 +581,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testJavaSide3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJavaSide3(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -566,7 +608,9 @@ public class RandomTests extends BaseNd4jTest { */ @Test - public void testJavaSide4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJavaSide4(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -599,7 +643,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testJavaSide5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJavaSide5(Nd4jBackend backend) { Nd4j.getRandom().setSeed(7); int length = 100; @@ -623,7 +669,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testBernoulliDistribution1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBernoulliDistribution1(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -643,7 +691,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testBernoulliDistribution2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBernoulliDistribution2(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -667,7 +717,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testBernoulliDistribution3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBernoulliDistribution3(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -692,7 +744,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testBinomialDistribution1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBinomialDistribution1(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -715,7 +769,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testBinomialDistribution2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBinomialDistribution2(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119); @@ -740,7 +796,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testMultithreading1() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultithreading1(Nd4jBackend backend) throws Exception { final AtomicInteger cnt = new AtomicInteger(0); final CopyOnWriteArrayList list = new CopyOnWriteArrayList<>(); @@ -751,18 +809,15 @@ public class RandomTests extends BaseNd4jTest { } for (int x = 0; x < threads.length; x++) { - threads[x] = new Thread(new Runnable() { - @Override - public void run() { - Random rnd = Nd4j.getRandom(); - rnd.setSeed(119); - float[] array = new float[10]; + threads[x] = new Thread(() -> { + Random rnd = Nd4j.getRandom(); + rnd.setSeed(119); + float[] array = new float[10]; - for (int e = 0; e < array.length; e++) { - array[e] = rnd.nextFloat(); - } - list.set(cnt.getAndIncrement(), array); + for (int e = 0; e < array.length; e++) { + array[e] = rnd.nextFloat(); } + list.set(cnt.getAndIncrement(), array); }); threads[x].start(); } @@ -781,6 +836,8 @@ public class RandomTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMultithreading2() throws Exception { final AtomicInteger cnt = new AtomicInteger(0); @@ -792,17 +849,14 @@ public class RandomTests extends BaseNd4jTest { } for (int x = 0; x < threads.length; x++) { - threads[x] = new Thread(new Runnable() { - @Override - public void run() { - Random rnd = Nd4j.getRandom(); - rnd.setSeed(119); - INDArray array = Nd4j.getExecutioner().exec(new UniformDistribution(Nd4j.createUninitialized(25))); + threads[x] = new Thread(() -> { + Random rnd = Nd4j.getRandom(); + rnd.setSeed(119); + INDArray array = Nd4j.getExecutioner().exec(new UniformDistribution(Nd4j.createUninitialized(25))); - Nd4j.getExecutioner().commit(); + Nd4j.getExecutioner().commit(); - list.set(cnt.getAndIncrement(), array); - } + list.set(cnt.getAndIncrement(), array); }); threads[x].start(); } @@ -821,7 +875,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testStepOver3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStepOver3(Nd4jBackend backend) { Random random = Nd4j.getRandomFactory().getNewRandomInstance(119); if (random instanceof NativeRandom) { NativeRandom rng = (NativeRandom) random; @@ -848,7 +904,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testStepOver4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStepOver4(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119, 100000); Random random2 = Nd4j.getRandomFactory().getNewRandomInstance(119, 100000); @@ -861,7 +919,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testSignatures1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSignatures1(Nd4jBackend backend) { for (int x = 0; x < 100; x++) { INDArray z1 = Nd4j.randn(5325235, new long[]{128, 1}); @@ -872,7 +932,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testChoice1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testChoice1(Nd4jBackend backend) { INDArray source = Nd4j.create(new double[] {1, 2, 3, 4, 5}); INDArray probs = Nd4j.create(new double[] {0.0, 0.0, 1.0, 0.0, 0.0}); INDArray exp = Nd4j.create(5).assign(3.0); @@ -882,7 +944,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testChoice2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testChoice2(Nd4jBackend backend) { INDArray source = Nd4j.create(new double[] {1, 2, 3, 4, 5}); INDArray probs = Nd4j.create(new double[] {0.0, 0.0, 0.0, 0.0, 0.0}); INDArray exp = Nd4j.create(5).assign(5.0); @@ -893,6 +957,8 @@ public class RandomTests extends BaseNd4jTest { @Disabled @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDeallocation1() throws Exception { while (true) { @@ -905,348 +971,350 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void someTest() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void someTest(Nd4jBackend backend) { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); INDArray x = Nd4j.create(new double[] {-0.5753774207320429, 1.0614372269091394, 0.4522970978070401, - -0.5752887679689271, 1.0636465735137173, 0.4544011796073467, -0.576361407698785, - 1.0656790105069853, 0.4552935317796974, -0.5760602684016433, 1.0658617022858135, - 0.4557330858969331, -0.5757970093448411, 1.0622487939115577, 0.45266130626880363, - -0.5752622961957029, 1.0582596824316828, 0.44949025343112814, -0.5771479956928688, - 1.0665372965638613, 0.4553688166885955, -0.5753088931923759, 1.0620227840548335, - 0.45289545873086556, -0.576588580700202, 1.0682150986638697, 0.457411469249719, - -0.5747325473572189, 1.0626318659592515, 0.4539743754957771, -0.5745380761623263, - 1.0581714324564084, 0.4500640145455051, -0.5756600950978087, 1.0634216668548728, - 0.4538595118971328, -0.5751140573519833, 1.0640115397234116, 0.45489343676357286, - -0.5772284666676437, 1.0696940198418068, 0.4581879096117204, -0.5744147982066905, - 1.0554839926243997, 0.4477135176681925, -0.5754198385793243, 1.0558429782980523, - 0.44713394665660644, -0.5761545677071064, 1.0598241808807554, 0.45011696447560207, - -0.5758387163599189, 1.0619667903192647, 0.4523652688352249, -0.5737984521578438, - 1.0551267152966937, 0.4479433219105848, -0.5759974232799963, 1.061302689492133, - 0.4516134441303072, -0.5736901589111626, 1.0576251048845364, 0.4503299444045488, - -0.5763311372167914, 1.06192192215954, 0.45187907799834365, -0.5778442414543, - 1.0674079152998242, 0.45553705763054314, -0.5758254570690142, 1.0620200161144016, - 0.4524260129848761, -0.5749775837304827, 1.062224210147449, 0.45337944519367585, - -0.574541903754345, 1.0619442384090578, 0.45351676811211955, -0.5760078457119082, - 1.062690890233097, 0.4528757342573996, -0.5748606750551666, 1.060141033285612, - 0.4515767478829046, -0.5749561834487571, 1.0606232394644224, 0.45193216220783466, - -0.5756803380730748, 1.064483756604441, 0.4548141798773699, -0.5752565746574122, - 1.0636651281176792, 0.4544472759986484, -0.5750760910978936, 1.0594989813795266, - 0.45079386382003334, -0.5751674161305798, 1.0590858567198587, 0.45033285969135406, - -0.5750886065307328, 1.0572011798927974, 0.4486775685374512, -0.5747325473572189, - 1.0626318659592515, 0.4539743754957771, -0.5757243201088236, 1.0633839362120128, - 0.45376689590426994, -0.5744411030524335, 1.0582391680513001, 0.45021371788814785, - -0.5747325473572189, 1.0626318659592515, 0.4539743754957771, -0.5769510974701872, - 1.0685324074495908, 0.4573744807836674, -0.5750191442942153, 1.0611238707219008, - 0.45233387445404916, -0.5763530480319555, 1.0632592080003551, 0.4530843416356724, - -0.5761681009423941, 1.0687223794712288, 0.4582562437719459, -0.5772202009540097, - 1.0683672322728441, 0.4569799298001917, -0.5770651807597004, 1.0636720905704742, - 0.4528188972040562, -0.5755594444325524, 1.0602552587289935, 0.4510497867771471, - -0.5760405012467995, 1.0650797166475576, 0.4550345871790212, -0.5753138307789047, - 1.0603836033532072, 0.451389365910235, -0.5764219486333497, 1.066178407227334, - 0.4556963003853961, -0.5748294633718319, 1.059070222875785, 0.450624005391455, - -0.5754032559272689, 1.062504307475741, 0.453251283125361, 0.357808280229093, - -0.17304804748832744, 0.1648877578656923, 0.3550956268779401, -0.16638470955750134, - 0.16854004156835015, 0.35761790317790293, -0.17225833018768533, 0.1654391291304103, - 0.3536090968875379, -0.1570909141136799, 0.17571031393597503, 0.3561854268639926, - -0.167380791258639, 0.16861259032124698, 0.3546448721372181, -0.161229935301283, - 0.17285482935309807, 0.354628589295547, -0.16574588493263773, 0.1687031152037963, - 0.3515608583761638, -0.15075008903410433, 0.17966769737990534, 0.35735084527857575, - -0.1696182518386006, 0.1676162794872508, 0.35146079433904887, -0.15372713783620343, - 0.17685002025939964, 0.3528734834345405, -0.1521597664861848, 0.17956276341866134, - 0.3532410649497478, -0.160680048791368, 0.1720897037995631, 0.356682698566458, - -0.16328251379445335, 0.17281643565506308, 0.3556302932619103, -0.16500416366377244, - 0.17028801230489224, 0.35211485765711686, -0.15678608646411626, 0.17463895406650265, - 0.35637497011042096, -0.1691665602108546, 0.16714799681616294, 0.35308078531675746, - -0.1592600519004829, 0.173245669482832, 0.3556196874799506, -0.16224708681088748, - 0.17280414441250597, 0.3559475841193771, -0.16396311971736327, 0.17152848950991376, - 0.35435929634532026, -0.15891041774418582, 0.17472158068918403, 0.3528490359864511, - -0.16132798573712082, 0.1711417922247098, 0.35462901944485786, -0.16272899207088296, - 0.17146723613971174, 0.3567480914698187, -0.16665684870871977, 0.16978436312547981, - 0.35677871524326865, -0.16619978521411394, 0.17023075253187472, 0.35606103185316756, - -0.16664741773206532, 0.16917198549729348, 0.3562273106630626, -0.16822741271934818, - 0.1678748703769742, 0.35803810004503234, -0.17145759936952631, 0.16655247328612868, - 0.3563871886834647, -0.16952991173201867, 0.1668261798007235, 0.35436973044992964, - -0.1626885508561808, 0.17126991846165585, 0.354059661856123, -0.15883963375895938, - 0.17451559223248628, 0.35397652790453105, -0.15754392604138207, 0.1756274285801798, - 0.35422920502812466, -0.15772901356550117, 0.1756862615390695, 0.35416424088944914, - -0.16022172948512917, 0.17334400078403028, 0.3555600143057507, -0.1643372584279808, - 0.17083543107967486, 0.3525087034842565, -0.1575072041681293, 0.17433433660001676, - 0.3531659556069536, -0.1624191446662591, 0.17042865350793346, 0.3565696507307317, - -0.1697220040407826, 0.166815130002541, 0.3568664974596232, -0.16577658963578037, - 0.1706977802149158, 0.35313668277505816, -0.15886990989683572, 0.17365359737044656, - 0.3533245352115322, -0.15723031113817726, 0.17532540564635665, 0.35460862238876345, - -0.1595238276259829, 0.17438500473685614, 0.35525250874776443, -0.16466741223783185, - 0.17025503480284157, 0.3545409063719635, -0.16055812395314287, 0.17337629382343148, - 0.35198952012701995, -0.15156979252918573, 0.17930423619280544, 0.3537953559292405, - -0.15906206241879808, 0.17407292855724904, 0.35415180834842913, -0.1607628482146717, - 0.1728370522185283, 0.3537998855935737, -0.1600845565243993, 0.1731403306802763, - 0.3554810273775851, -0.16489175524215102, 0.17025607008052857, 0.3508232195628162, - -0.15082599073411826, 0.17893143035496875, 0.35370792374178356, -0.15961008691395126, - 0.17349186328292782, 0.05450698542491758, -0.41874678698827594, -0.3343403087067353, - 0.05498792881564898, -0.41460440299356255, -0.33011081679631604, 0.059046779421456655, - -0.42765937881362637, -0.3384015915928204, 0.057799646609788376, -0.4216980629472357, - -0.3340677702649465, 0.05660348398009795, -0.42152485671613177, -0.3349902821139396, - 0.062105535400888166, -0.4346085458257504, -0.34200288508621907, 0.05234240369292872, - -0.4055153621656568, -0.32417570593377165, 0.062317826890744256, -0.43305655048852, - -0.3403892391519301, 0.05999457207577438, -0.4256813340236285, -0.3357328454874602, - 0.05678917347058686, -0.42675689269642103, -0.3396154345679126, 0.05573207104665189, - -0.42026752129437106, -0.33462610511478547, 0.05714994401613468, -0.4205474351785073, - -0.3336009477907372, 0.05726741118080793, -0.4235566120776033, -0.33625143560385046, - 0.05425935432791021, -0.41257249501421506, -0.3289079566747622, 0.052303907040540346, - -0.41140905979351317, -0.3296096337421601, 0.05435673726351468, -0.41816551767450155, - -0.33394362207267464, 0.057990097876612606, -0.4230779753541936, -0.3351597436911609, - 0.06092405835204879, -0.43557367866707497, -0.3439549390223794, 0.06258076523991336, - -0.4348580034399873, -0.34180186041642535, 0.058062574332262445, -0.41817120503133454, - -0.3305992122298355, 0.056667222489482326, -0.4238987396752845, -0.33710735037338646, - 0.05331470690325209, -0.4115191773603726, -0.32879687237030414, 0.06346989358689988, - -0.4364823047245409, -0.34248619701107236, 0.05644334203874793, -0.4187549773420934, - -0.3325975840580061, 0.057027209491249405, -0.4237123787110853, -0.33661124389650665, - 0.060880554331858905, -0.43118060018960624, -0.3399698253230048, 0.055782276874872014, - -0.41756220328628996, -0.3321024223397344, 0.055446756449796616, -0.41723881862577716, - -0.3321094434159544, 0.056669295076943016, -0.4204495146444979, -0.3339456915975489, - 0.06175751688116003, -0.4316649644441389, -0.3396208783911239, 0.06173341588037583, - -0.43230523620229205, -0.3402292064686406, 0.061875407606504736, -0.4375444226515741, - -0.3449004067573839, 0.05614720683299703, -0.41976126084969595, -0.33378709562366216, - 0.05830426102392656, -0.4215960573784757, -0.3335182151970676, 0.05970588243873762, - -0.42240282576500415, -0.33299039110293827, 0.0601036415102731, -0.43192922004298595, - -0.341357858629765, 0.053969290626186925, -0.4179579756815318, -0.3341036998452465, - 0.057561584042216396, -0.4222868516155249, -0.3348223303093653, 0.05493041051018744, - -0.4159784070400405, -0.3314215115884726, 0.05717139029405299, -0.4240689677351035, - -0.3368075882948835, 0.05549340668146566, -0.42106549309448166, -0.3355728387667392, - 0.05541833273044943, -0.4214816855580636, -0.3360219642916261, 0.05498792881564898, - -0.41460440299356255, -0.33011081679631604, 0.05685500116978513, -0.42385292467441293, - -0.336895651128018, 0.054911515245869534, -0.4208844109998573, -0.33593291020280946, - 0.05523742901993513, -0.4201363809538045, -0.3349530647612928, 0.05645313933816329, - -0.4183606591328087, -0.33222749927008216, 0.05624807825529575, -0.42052328095315095, - -0.33439399594116775, 0.05375363814021924, -0.4170276113599561, -0.33344632974645483, - 0.05534618656451573, -0.41631425766460334, -0.33135336920555164}, new int[] {150, 3}, 'c'); + -0.5752887679689271, 1.0636465735137173, 0.4544011796073467, -0.576361407698785, + 1.0656790105069853, 0.4552935317796974, -0.5760602684016433, 1.0658617022858135, + 0.4557330858969331, -0.5757970093448411, 1.0622487939115577, 0.45266130626880363, + -0.5752622961957029, 1.0582596824316828, 0.44949025343112814, -0.5771479956928688, + 1.0665372965638613, 0.4553688166885955, -0.5753088931923759, 1.0620227840548335, + 0.45289545873086556, -0.576588580700202, 1.0682150986638697, 0.457411469249719, + -0.5747325473572189, 1.0626318659592515, 0.4539743754957771, -0.5745380761623263, + 1.0581714324564084, 0.4500640145455051, -0.5756600950978087, 1.0634216668548728, + 0.4538595118971328, -0.5751140573519833, 1.0640115397234116, 0.45489343676357286, + -0.5772284666676437, 1.0696940198418068, 0.4581879096117204, -0.5744147982066905, + 1.0554839926243997, 0.4477135176681925, -0.5754198385793243, 1.0558429782980523, + 0.44713394665660644, -0.5761545677071064, 1.0598241808807554, 0.45011696447560207, + -0.5758387163599189, 1.0619667903192647, 0.4523652688352249, -0.5737984521578438, + 1.0551267152966937, 0.4479433219105848, -0.5759974232799963, 1.061302689492133, + 0.4516134441303072, -0.5736901589111626, 1.0576251048845364, 0.4503299444045488, + -0.5763311372167914, 1.06192192215954, 0.45187907799834365, -0.5778442414543, + 1.0674079152998242, 0.45553705763054314, -0.5758254570690142, 1.0620200161144016, + 0.4524260129848761, -0.5749775837304827, 1.062224210147449, 0.45337944519367585, + -0.574541903754345, 1.0619442384090578, 0.45351676811211955, -0.5760078457119082, + 1.062690890233097, 0.4528757342573996, -0.5748606750551666, 1.060141033285612, + 0.4515767478829046, -0.5749561834487571, 1.0606232394644224, 0.45193216220783466, + -0.5756803380730748, 1.064483756604441, 0.4548141798773699, -0.5752565746574122, + 1.0636651281176792, 0.4544472759986484, -0.5750760910978936, 1.0594989813795266, + 0.45079386382003334, -0.5751674161305798, 1.0590858567198587, 0.45033285969135406, + -0.5750886065307328, 1.0572011798927974, 0.4486775685374512, -0.5747325473572189, + 1.0626318659592515, 0.4539743754957771, -0.5757243201088236, 1.0633839362120128, + 0.45376689590426994, -0.5744411030524335, 1.0582391680513001, 0.45021371788814785, + -0.5747325473572189, 1.0626318659592515, 0.4539743754957771, -0.5769510974701872, + 1.0685324074495908, 0.4573744807836674, -0.5750191442942153, 1.0611238707219008, + 0.45233387445404916, -0.5763530480319555, 1.0632592080003551, 0.4530843416356724, + -0.5761681009423941, 1.0687223794712288, 0.4582562437719459, -0.5772202009540097, + 1.0683672322728441, 0.4569799298001917, -0.5770651807597004, 1.0636720905704742, + 0.4528188972040562, -0.5755594444325524, 1.0602552587289935, 0.4510497867771471, + -0.5760405012467995, 1.0650797166475576, 0.4550345871790212, -0.5753138307789047, + 1.0603836033532072, 0.451389365910235, -0.5764219486333497, 1.066178407227334, + 0.4556963003853961, -0.5748294633718319, 1.059070222875785, 0.450624005391455, + -0.5754032559272689, 1.062504307475741, 0.453251283125361, 0.357808280229093, + -0.17304804748832744, 0.1648877578656923, 0.3550956268779401, -0.16638470955750134, + 0.16854004156835015, 0.35761790317790293, -0.17225833018768533, 0.1654391291304103, + 0.3536090968875379, -0.1570909141136799, 0.17571031393597503, 0.3561854268639926, + -0.167380791258639, 0.16861259032124698, 0.3546448721372181, -0.161229935301283, + 0.17285482935309807, 0.354628589295547, -0.16574588493263773, 0.1687031152037963, + 0.3515608583761638, -0.15075008903410433, 0.17966769737990534, 0.35735084527857575, + -0.1696182518386006, 0.1676162794872508, 0.35146079433904887, -0.15372713783620343, + 0.17685002025939964, 0.3528734834345405, -0.1521597664861848, 0.17956276341866134, + 0.3532410649497478, -0.160680048791368, 0.1720897037995631, 0.356682698566458, + -0.16328251379445335, 0.17281643565506308, 0.3556302932619103, -0.16500416366377244, + 0.17028801230489224, 0.35211485765711686, -0.15678608646411626, 0.17463895406650265, + 0.35637497011042096, -0.1691665602108546, 0.16714799681616294, 0.35308078531675746, + -0.1592600519004829, 0.173245669482832, 0.3556196874799506, -0.16224708681088748, + 0.17280414441250597, 0.3559475841193771, -0.16396311971736327, 0.17152848950991376, + 0.35435929634532026, -0.15891041774418582, 0.17472158068918403, 0.3528490359864511, + -0.16132798573712082, 0.1711417922247098, 0.35462901944485786, -0.16272899207088296, + 0.17146723613971174, 0.3567480914698187, -0.16665684870871977, 0.16978436312547981, + 0.35677871524326865, -0.16619978521411394, 0.17023075253187472, 0.35606103185316756, + -0.16664741773206532, 0.16917198549729348, 0.3562273106630626, -0.16822741271934818, + 0.1678748703769742, 0.35803810004503234, -0.17145759936952631, 0.16655247328612868, + 0.3563871886834647, -0.16952991173201867, 0.1668261798007235, 0.35436973044992964, + -0.1626885508561808, 0.17126991846165585, 0.354059661856123, -0.15883963375895938, + 0.17451559223248628, 0.35397652790453105, -0.15754392604138207, 0.1756274285801798, + 0.35422920502812466, -0.15772901356550117, 0.1756862615390695, 0.35416424088944914, + -0.16022172948512917, 0.17334400078403028, 0.3555600143057507, -0.1643372584279808, + 0.17083543107967486, 0.3525087034842565, -0.1575072041681293, 0.17433433660001676, + 0.3531659556069536, -0.1624191446662591, 0.17042865350793346, 0.3565696507307317, + -0.1697220040407826, 0.166815130002541, 0.3568664974596232, -0.16577658963578037, + 0.1706977802149158, 0.35313668277505816, -0.15886990989683572, 0.17365359737044656, + 0.3533245352115322, -0.15723031113817726, 0.17532540564635665, 0.35460862238876345, + -0.1595238276259829, 0.17438500473685614, 0.35525250874776443, -0.16466741223783185, + 0.17025503480284157, 0.3545409063719635, -0.16055812395314287, 0.17337629382343148, + 0.35198952012701995, -0.15156979252918573, 0.17930423619280544, 0.3537953559292405, + -0.15906206241879808, 0.17407292855724904, 0.35415180834842913, -0.1607628482146717, + 0.1728370522185283, 0.3537998855935737, -0.1600845565243993, 0.1731403306802763, + 0.3554810273775851, -0.16489175524215102, 0.17025607008052857, 0.3508232195628162, + -0.15082599073411826, 0.17893143035496875, 0.35370792374178356, -0.15961008691395126, + 0.17349186328292782, 0.05450698542491758, -0.41874678698827594, -0.3343403087067353, + 0.05498792881564898, -0.41460440299356255, -0.33011081679631604, 0.059046779421456655, + -0.42765937881362637, -0.3384015915928204, 0.057799646609788376, -0.4216980629472357, + -0.3340677702649465, 0.05660348398009795, -0.42152485671613177, -0.3349902821139396, + 0.062105535400888166, -0.4346085458257504, -0.34200288508621907, 0.05234240369292872, + -0.4055153621656568, -0.32417570593377165, 0.062317826890744256, -0.43305655048852, + -0.3403892391519301, 0.05999457207577438, -0.4256813340236285, -0.3357328454874602, + 0.05678917347058686, -0.42675689269642103, -0.3396154345679126, 0.05573207104665189, + -0.42026752129437106, -0.33462610511478547, 0.05714994401613468, -0.4205474351785073, + -0.3336009477907372, 0.05726741118080793, -0.4235566120776033, -0.33625143560385046, + 0.05425935432791021, -0.41257249501421506, -0.3289079566747622, 0.052303907040540346, + -0.41140905979351317, -0.3296096337421601, 0.05435673726351468, -0.41816551767450155, + -0.33394362207267464, 0.057990097876612606, -0.4230779753541936, -0.3351597436911609, + 0.06092405835204879, -0.43557367866707497, -0.3439549390223794, 0.06258076523991336, + -0.4348580034399873, -0.34180186041642535, 0.058062574332262445, -0.41817120503133454, + -0.3305992122298355, 0.056667222489482326, -0.4238987396752845, -0.33710735037338646, + 0.05331470690325209, -0.4115191773603726, -0.32879687237030414, 0.06346989358689988, + -0.4364823047245409, -0.34248619701107236, 0.05644334203874793, -0.4187549773420934, + -0.3325975840580061, 0.057027209491249405, -0.4237123787110853, -0.33661124389650665, + 0.060880554331858905, -0.43118060018960624, -0.3399698253230048, 0.055782276874872014, + -0.41756220328628996, -0.3321024223397344, 0.055446756449796616, -0.41723881862577716, + -0.3321094434159544, 0.056669295076943016, -0.4204495146444979, -0.3339456915975489, + 0.06175751688116003, -0.4316649644441389, -0.3396208783911239, 0.06173341588037583, + -0.43230523620229205, -0.3402292064686406, 0.061875407606504736, -0.4375444226515741, + -0.3449004067573839, 0.05614720683299703, -0.41976126084969595, -0.33378709562366216, + 0.05830426102392656, -0.4215960573784757, -0.3335182151970676, 0.05970588243873762, + -0.42240282576500415, -0.33299039110293827, 0.0601036415102731, -0.43192922004298595, + -0.341357858629765, 0.053969290626186925, -0.4179579756815318, -0.3341036998452465, + 0.057561584042216396, -0.4222868516155249, -0.3348223303093653, 0.05493041051018744, + -0.4159784070400405, -0.3314215115884726, 0.05717139029405299, -0.4240689677351035, + -0.3368075882948835, 0.05549340668146566, -0.42106549309448166, -0.3355728387667392, + 0.05541833273044943, -0.4214816855580636, -0.3360219642916261, 0.05498792881564898, + -0.41460440299356255, -0.33011081679631604, 0.05685500116978513, -0.42385292467441293, + -0.336895651128018, 0.054911515245869534, -0.4208844109998573, -0.33593291020280946, + 0.05523742901993513, -0.4201363809538045, -0.3349530647612928, 0.05645313933816329, + -0.4183606591328087, -0.33222749927008216, 0.05624807825529575, -0.42052328095315095, + -0.33439399594116775, 0.05375363814021924, -0.4170276113599561, -0.33344632974645483, + 0.05534618656451573, -0.41631425766460334, -0.33135336920555164}, new int[] {150, 3}, 'c'); INDArray y = Nd4j.create(new double[] {0.2429357202832011, 0.24691828776293456, 0.24583756032730986, - 0.24705968192172242, 0.24232827188842557, 0.23861718711877997, 0.24465104629395537, - 0.2442738260690249, 0.24817946695739254, 0.24674303971330414, 0.2406225888573061, - 0.24498353214674504, 0.24727346570657688, 0.24750033019519094, 0.23490644039834577, - 0.23069196589254812, 0.23707179721263205, 0.2426418870958991, 0.23908522418834394, - 0.24018555003273495, 0.24389367527050315, 0.24082645976621256, 0.24213120625453619, - 0.24456405220771912, 0.24572256827067002, 0.24714763826639866, 0.2440167016142005, - 0.24298275912681128, 0.24351951667026728, 0.24651027782106275, 0.24692250660853385, - 0.24274136220433426, 0.23744386161723494, 0.23442898757139466, 0.24674303971330414, - 0.2449728421653453, 0.2415619815480595, 0.24674303971330414, 0.247534800371069, - 0.24404181458275354, 0.24259378898938533, 0.24971978740576964, 0.2464087697036708, - 0.24262430969703594, 0.24122273758752688, 0.2468932971669087, 0.24086424499542164, - 0.24625774119619978, 0.24091022329137776, 0.2447601281813779, 0.24650971563804278, - 0.24694511135813532, 0.24744726876889875, 0.2499998745489847, 0.2488199078775225, - 0.2495902563502858, 0.24675189476120526, 0.24998568760015533, 0.24859167567513393, - 0.24959567465133717, 0.24963361855789493, 0.24817253327065975, 0.249994140708285, - 0.24911602549653417, 0.24836829906001545, 0.24699702031098014, 0.24893611893905115, - 0.24965316728681577, 0.2499998535029291, 0.24988790472862127, 0.24775308931283793, - 0.24873108439991887, 0.2498966240270357, 0.24955939151129433, 0.2484053958352789, - 0.24768482160702535, 0.2488742705575876, 0.24806442277140775, 0.2488691246208544, - 0.24947483938321652, 0.24996777647841242, 0.24996708063398507, 0.24933140745588728, - 0.24980398107656687, 0.2491285118494222, 0.24625987230771443, 0.24738925580304405, - 0.24997065883924532, 0.2486746004152257, 0.2498843535089784, 0.24995370354353233, - 0.24864990162099362, 0.2496455019084883, 0.24999998132131074, 0.24963836089792163, - 0.24882032311579313, 0.2490425997813645, 0.24864262256994435, 0.24962396682576615, - 0.24923850229014768, 0.24758302979788205, 0.2497364221125665, 0.24833712214381948, - 0.24945237443576807, 0.2487672378102696, 0.24872684754912008, 0.24999957029328843, - 0.24932200515812822, 0.24997350874499308, 0.244153258352151, 0.2470095569300082, - 0.2495218586932054, 0.24811062315859542, 0.249937451895014, 0.24908117922763642, - 0.2469981453066533, 0.24887643643724272, 0.24388574974855176, 0.2497890295896134, - 0.24985753528646346, 0.24695464048531734, 0.2494344204235447, 0.24945925714694822, - 0.24932962141151696, 0.24711291849801428, 0.24792873650284003, 0.24901637934722654, - 0.24853055677109265, 0.2493702155100004, 0.24875396650445886, 0.24922420532388534, - 0.24308868874709608, 0.24927788200579612, 0.2495128512715364, 0.2499941533196743, - 0.24750947275677102, 0.24642116577310766, 0.2486200299475723, 0.24850849338968475, - 0.24729748244118202, 0.24744445799632175, 0.24628385050522522, 0.2497364221125665, - 0.24749963780305764, 0.2463077044861883, 0.24740549843131626, 0.24976137204060936, - 0.24818886413733424, 0.24637984563882515, 0.24900312437125263, 0.24725246251464808, - 0.2487629429430977, 0.24865171031370353, 0.24898122102519538, 0.24717579119456606, - 0.2460980895111744, 0.24863419281317758, 0.247722109456114, 0.249483330285878, - 0.24833150510043103, 0.24597530718656643, 0.24808523436616853, 0.24868160690258487, - 0.24928092765244053, 0.24371474005808394, 0.2433465092843804, 0.24605554602979549, - 0.24756485929873812, 0.24528840047309738, 0.24672264761923993, 0.24694122129988064, - 0.24734547447905952, 0.24790888678135556, 0.24858706085548732, 0.24810764262100565, - 0.24863482616236307, 0.24827880810787284, 0.24705502393438042, 0.24732810186246684, - 0.24867054596532764, 0.2487232960074739, 0.24756546852978628, 0.24464019605672277, - 0.24381663076095264, 0.24833150510043103, 0.24818162318958523, 0.24637168326866488, - 0.24833150510043103, 0.249349601679655, 0.24753319754810257, 0.24774444170583876, - 0.24996939860971737, 0.24904696608261323, 0.2485561286364399, 0.24710085120475833, - 0.24909198676990582, 0.24637328290411886, 0.2487969512643165, 0.2462161750985027, - 0.24796208770766182, 0.2482717242774289, 0.24927296051372536, 0.24886483945735274, - 0.24993221481965383, 0.24967300752206745, 0.24992406337114595, 0.24940315381994102, - 0.24994552914860874, 0.24916401420686277, 0.24995963968339815, 0.24973233672772352, - 0.2498543255582006, 0.24993325688497398, 0.24974525038828596, 0.24989546071480015, - 0.2488904029901035, 0.24996473851724282, 0.24968939120009442, 0.24998514600584043, - 0.249970842490889, 0.24994173055054456, 0.24971192503658038, 0.24996592891614833, - 0.24962101269911038, 0.24936160593652545, 0.2491775995760193, 0.24927805303451178, - 0.24956603508121747, 0.24987730415488038, 0.24981789385147107, 0.24999930237271728, - 0.24998381258446437, 0.24986548776192077, 0.24999570010002634, 0.24999591182398645, - 0.24954087006361633, 0.24910219683524124, 0.24994978142288657, 0.24984413117714582, - 0.2499918867206099, 0.24999323357114836, 0.24964894224451797, 0.24992220590291733, - 0.2499341504594762, 0.2499813141457409, 0.24969535176901464, 0.2498639405365019, - 0.24954381392197553, 0.24998693285959547, 0.24991793303705503, 0.24997894264010345, - 0.24987574787239242, 0.24976012810712447, 0.24995513300895975, 0.24999947057326638, - 0.2493918776687613, 0.24924872331895642, 0.24934183397707277, 0.24998712583879595, - 0.24955837700116507, 0.2498331267576389, 0.24999981660905926, 0.24990076371927936, - 0.24953797044297324, 0.2494265251092823, 0.2499977970977972, 0.24982068183227324, - 0.24798657706157284, 0.24991157204598433, 0.249932740378162, 0.24988056441711978, - 0.24976599942821817, 0.24941195957117365, 0.24999632523994095, 0.24974702472240795, - 0.24897521663897507, 0.24999250490090896, 0.24996305238911065, 0.2499888335767531, - 0.24891106929086498, 0.24951603791242954, 0.24695890528687767, 0.24995848440320612, - 0.24981122350168533, 0.2499494301459412, 0.24956512364416267, 0.2499975951279776, - 0.2498004863662123, 0.24998230392115176, 0.24978402532449248, 0.2499982957455684, - 0.2499233623408743, 0.24987574787239242, 0.24992258224573397, 0.2499877309539672, - 0.24999575770242835, 0.2499556931106015, 0.24994451980664456, 0.2499917805496643, - 0.2499960409062574, 0.2488906858411297, 0.24927221432053384, 0.24959263428325498, - 0.2496480416964974, 0.2490800126913869, 0.24700417829410917, 0.24947297522916015, - 0.24903554141985265, 0.24986212683419776, 0.2494710235394941, 0.24812395570517215, - 0.24933303959179373, 0.24963609512126977, 0.2499981104347387, 0.2471864446636986, - 0.24611234486702, 0.2473548098091017, 0.24854393204705955, 0.24653404287294756, - 0.24846732083974551, 0.24799207032097062, 0.24806934707439765, 0.24977920077440272, - 0.24747851578074453, 0.2491977800330258, 0.24899684819566759, 0.2482703871176769, - 0.24861888752204364, 0.24868382859250693, 0.2494854381282655, 0.24934135158120313, - 0.24721040655966006, 0.24893513851554547, 0.24790515403657404, 0.2494710235394941, - 0.2491904936736734, 0.24801381586026497, 0.2494710235394941, 0.2498798518999066, - 0.24883828402114708, 0.24882333106825177, 0.2496579340850042, 0.24987747585242445, - 0.2473671119886922, 0.24776221560479675, 0.2491701169747418, 0.24876589576359523, - 0.24967670262514718, 0.24837643713184637, 0.24908976954024437, 0.22109653049902084, - 0.22548404062580826, 0.2200516152327322, 0.23608366774470158, 0.22444171845768218, - 0.23365223802969015, 0.22447052628573314, 0.24346374266820078, 0.2263252080678485, - 0.23714899298172334, 0.2427377880302641, 0.23030753471526916, 0.23596746213525796, - 0.22891052329406103, 0.23607820782805589, 0.2245286616708374, 0.2319875048139398, - 0.2370609309083055, 0.22732603437985477, 0.23769735050651258, 0.2249363649166962, - 0.2316955228651909, 0.22550619621604487, 0.23145195006251598, 0.22866901371862267, - 0.22540992944493907, 0.22272184895874236, 0.21868396394543543, 0.2288664303844395, - 0.2387563329217531, 0.23851371439846614, 0.2396548175328528, 0.235265249196123, - 0.22619919934323457, 0.2334753125319801, 0.22747857709673994, 0.22237888950209675, - 0.2293405393777329, 0.23512174304783892, 0.23605961692837263, 0.2363884920872055, - 0.22911784265366114, 0.23508631539486008, 0.24298537417416582, 0.23496693141570144, - 0.2353505896619777, 0.23423355787376254, 0.23026861594819384, 0.24205955937189744, - 0.23444214017181403, 0.20711228404659351, 0.22376613724792477, 0.20579434619205228, - 0.2194026188338758, 0.2106294523516182, 0.19834879650946866, 0.23482430646847627, - 0.20760503194119267, 0.2151216557564706, 0.197085900182565, 0.21569731990066776, - 0.217758419139472, 0.21004598240061953, 0.22350675479473048, 0.21633307169842297, - 0.21137335744670882, 0.2177596292210227, 0.19501204554661022, 0.19284473370296384, - 0.22784579330881613, 0.20506242624967894, 0.22457929209637387, 0.19874826014581815, - 0.22121053191398565, 0.2104514538390277, 0.2094343294766834, 0.22234971134381387, - 0.22296768210934845, 0.21382831666860633, 0.21326665887229082, 0.20547154049147015, - 0.1972613998245205, 0.21223025778509932, 0.22498680501540697, 0.22691830656416243, - 0.19521869814219142, 0.20988589092182444, 0.21868921978115488, 0.2240953476388927, - 0.20928467448290794, 0.20578906266338623, 0.20681378450968757, 0.22376613724792477, - 0.20553847889431168, 0.20376157669628492, 0.20862506432378858, 0.2195097946800901, - 0.21546464912779167, 0.2130691837611387, 0.22424979277256304}, new int[] {150, 3}, 'f'); + 0.24705968192172242, 0.24232827188842557, 0.23861718711877997, 0.24465104629395537, + 0.2442738260690249, 0.24817946695739254, 0.24674303971330414, 0.2406225888573061, + 0.24498353214674504, 0.24727346570657688, 0.24750033019519094, 0.23490644039834577, + 0.23069196589254812, 0.23707179721263205, 0.2426418870958991, 0.23908522418834394, + 0.24018555003273495, 0.24389367527050315, 0.24082645976621256, 0.24213120625453619, + 0.24456405220771912, 0.24572256827067002, 0.24714763826639866, 0.2440167016142005, + 0.24298275912681128, 0.24351951667026728, 0.24651027782106275, 0.24692250660853385, + 0.24274136220433426, 0.23744386161723494, 0.23442898757139466, 0.24674303971330414, + 0.2449728421653453, 0.2415619815480595, 0.24674303971330414, 0.247534800371069, + 0.24404181458275354, 0.24259378898938533, 0.24971978740576964, 0.2464087697036708, + 0.24262430969703594, 0.24122273758752688, 0.2468932971669087, 0.24086424499542164, + 0.24625774119619978, 0.24091022329137776, 0.2447601281813779, 0.24650971563804278, + 0.24694511135813532, 0.24744726876889875, 0.2499998745489847, 0.2488199078775225, + 0.2495902563502858, 0.24675189476120526, 0.24998568760015533, 0.24859167567513393, + 0.24959567465133717, 0.24963361855789493, 0.24817253327065975, 0.249994140708285, + 0.24911602549653417, 0.24836829906001545, 0.24699702031098014, 0.24893611893905115, + 0.24965316728681577, 0.2499998535029291, 0.24988790472862127, 0.24775308931283793, + 0.24873108439991887, 0.2498966240270357, 0.24955939151129433, 0.2484053958352789, + 0.24768482160702535, 0.2488742705575876, 0.24806442277140775, 0.2488691246208544, + 0.24947483938321652, 0.24996777647841242, 0.24996708063398507, 0.24933140745588728, + 0.24980398107656687, 0.2491285118494222, 0.24625987230771443, 0.24738925580304405, + 0.24997065883924532, 0.2486746004152257, 0.2498843535089784, 0.24995370354353233, + 0.24864990162099362, 0.2496455019084883, 0.24999998132131074, 0.24963836089792163, + 0.24882032311579313, 0.2490425997813645, 0.24864262256994435, 0.24962396682576615, + 0.24923850229014768, 0.24758302979788205, 0.2497364221125665, 0.24833712214381948, + 0.24945237443576807, 0.2487672378102696, 0.24872684754912008, 0.24999957029328843, + 0.24932200515812822, 0.24997350874499308, 0.244153258352151, 0.2470095569300082, + 0.2495218586932054, 0.24811062315859542, 0.249937451895014, 0.24908117922763642, + 0.2469981453066533, 0.24887643643724272, 0.24388574974855176, 0.2497890295896134, + 0.24985753528646346, 0.24695464048531734, 0.2494344204235447, 0.24945925714694822, + 0.24932962141151696, 0.24711291849801428, 0.24792873650284003, 0.24901637934722654, + 0.24853055677109265, 0.2493702155100004, 0.24875396650445886, 0.24922420532388534, + 0.24308868874709608, 0.24927788200579612, 0.2495128512715364, 0.2499941533196743, + 0.24750947275677102, 0.24642116577310766, 0.2486200299475723, 0.24850849338968475, + 0.24729748244118202, 0.24744445799632175, 0.24628385050522522, 0.2497364221125665, + 0.24749963780305764, 0.2463077044861883, 0.24740549843131626, 0.24976137204060936, + 0.24818886413733424, 0.24637984563882515, 0.24900312437125263, 0.24725246251464808, + 0.2487629429430977, 0.24865171031370353, 0.24898122102519538, 0.24717579119456606, + 0.2460980895111744, 0.24863419281317758, 0.247722109456114, 0.249483330285878, + 0.24833150510043103, 0.24597530718656643, 0.24808523436616853, 0.24868160690258487, + 0.24928092765244053, 0.24371474005808394, 0.2433465092843804, 0.24605554602979549, + 0.24756485929873812, 0.24528840047309738, 0.24672264761923993, 0.24694122129988064, + 0.24734547447905952, 0.24790888678135556, 0.24858706085548732, 0.24810764262100565, + 0.24863482616236307, 0.24827880810787284, 0.24705502393438042, 0.24732810186246684, + 0.24867054596532764, 0.2487232960074739, 0.24756546852978628, 0.24464019605672277, + 0.24381663076095264, 0.24833150510043103, 0.24818162318958523, 0.24637168326866488, + 0.24833150510043103, 0.249349601679655, 0.24753319754810257, 0.24774444170583876, + 0.24996939860971737, 0.24904696608261323, 0.2485561286364399, 0.24710085120475833, + 0.24909198676990582, 0.24637328290411886, 0.2487969512643165, 0.2462161750985027, + 0.24796208770766182, 0.2482717242774289, 0.24927296051372536, 0.24886483945735274, + 0.24993221481965383, 0.24967300752206745, 0.24992406337114595, 0.24940315381994102, + 0.24994552914860874, 0.24916401420686277, 0.24995963968339815, 0.24973233672772352, + 0.2498543255582006, 0.24993325688497398, 0.24974525038828596, 0.24989546071480015, + 0.2488904029901035, 0.24996473851724282, 0.24968939120009442, 0.24998514600584043, + 0.249970842490889, 0.24994173055054456, 0.24971192503658038, 0.24996592891614833, + 0.24962101269911038, 0.24936160593652545, 0.2491775995760193, 0.24927805303451178, + 0.24956603508121747, 0.24987730415488038, 0.24981789385147107, 0.24999930237271728, + 0.24998381258446437, 0.24986548776192077, 0.24999570010002634, 0.24999591182398645, + 0.24954087006361633, 0.24910219683524124, 0.24994978142288657, 0.24984413117714582, + 0.2499918867206099, 0.24999323357114836, 0.24964894224451797, 0.24992220590291733, + 0.2499341504594762, 0.2499813141457409, 0.24969535176901464, 0.2498639405365019, + 0.24954381392197553, 0.24998693285959547, 0.24991793303705503, 0.24997894264010345, + 0.24987574787239242, 0.24976012810712447, 0.24995513300895975, 0.24999947057326638, + 0.2493918776687613, 0.24924872331895642, 0.24934183397707277, 0.24998712583879595, + 0.24955837700116507, 0.2498331267576389, 0.24999981660905926, 0.24990076371927936, + 0.24953797044297324, 0.2494265251092823, 0.2499977970977972, 0.24982068183227324, + 0.24798657706157284, 0.24991157204598433, 0.249932740378162, 0.24988056441711978, + 0.24976599942821817, 0.24941195957117365, 0.24999632523994095, 0.24974702472240795, + 0.24897521663897507, 0.24999250490090896, 0.24996305238911065, 0.2499888335767531, + 0.24891106929086498, 0.24951603791242954, 0.24695890528687767, 0.24995848440320612, + 0.24981122350168533, 0.2499494301459412, 0.24956512364416267, 0.2499975951279776, + 0.2498004863662123, 0.24998230392115176, 0.24978402532449248, 0.2499982957455684, + 0.2499233623408743, 0.24987574787239242, 0.24992258224573397, 0.2499877309539672, + 0.24999575770242835, 0.2499556931106015, 0.24994451980664456, 0.2499917805496643, + 0.2499960409062574, 0.2488906858411297, 0.24927221432053384, 0.24959263428325498, + 0.2496480416964974, 0.2490800126913869, 0.24700417829410917, 0.24947297522916015, + 0.24903554141985265, 0.24986212683419776, 0.2494710235394941, 0.24812395570517215, + 0.24933303959179373, 0.24963609512126977, 0.2499981104347387, 0.2471864446636986, + 0.24611234486702, 0.2473548098091017, 0.24854393204705955, 0.24653404287294756, + 0.24846732083974551, 0.24799207032097062, 0.24806934707439765, 0.24977920077440272, + 0.24747851578074453, 0.2491977800330258, 0.24899684819566759, 0.2482703871176769, + 0.24861888752204364, 0.24868382859250693, 0.2494854381282655, 0.24934135158120313, + 0.24721040655966006, 0.24893513851554547, 0.24790515403657404, 0.2494710235394941, + 0.2491904936736734, 0.24801381586026497, 0.2494710235394941, 0.2498798518999066, + 0.24883828402114708, 0.24882333106825177, 0.2496579340850042, 0.24987747585242445, + 0.2473671119886922, 0.24776221560479675, 0.2491701169747418, 0.24876589576359523, + 0.24967670262514718, 0.24837643713184637, 0.24908976954024437, 0.22109653049902084, + 0.22548404062580826, 0.2200516152327322, 0.23608366774470158, 0.22444171845768218, + 0.23365223802969015, 0.22447052628573314, 0.24346374266820078, 0.2263252080678485, + 0.23714899298172334, 0.2427377880302641, 0.23030753471526916, 0.23596746213525796, + 0.22891052329406103, 0.23607820782805589, 0.2245286616708374, 0.2319875048139398, + 0.2370609309083055, 0.22732603437985477, 0.23769735050651258, 0.2249363649166962, + 0.2316955228651909, 0.22550619621604487, 0.23145195006251598, 0.22866901371862267, + 0.22540992944493907, 0.22272184895874236, 0.21868396394543543, 0.2288664303844395, + 0.2387563329217531, 0.23851371439846614, 0.2396548175328528, 0.235265249196123, + 0.22619919934323457, 0.2334753125319801, 0.22747857709673994, 0.22237888950209675, + 0.2293405393777329, 0.23512174304783892, 0.23605961692837263, 0.2363884920872055, + 0.22911784265366114, 0.23508631539486008, 0.24298537417416582, 0.23496693141570144, + 0.2353505896619777, 0.23423355787376254, 0.23026861594819384, 0.24205955937189744, + 0.23444214017181403, 0.20711228404659351, 0.22376613724792477, 0.20579434619205228, + 0.2194026188338758, 0.2106294523516182, 0.19834879650946866, 0.23482430646847627, + 0.20760503194119267, 0.2151216557564706, 0.197085900182565, 0.21569731990066776, + 0.217758419139472, 0.21004598240061953, 0.22350675479473048, 0.21633307169842297, + 0.21137335744670882, 0.2177596292210227, 0.19501204554661022, 0.19284473370296384, + 0.22784579330881613, 0.20506242624967894, 0.22457929209637387, 0.19874826014581815, + 0.22121053191398565, 0.2104514538390277, 0.2094343294766834, 0.22234971134381387, + 0.22296768210934845, 0.21382831666860633, 0.21326665887229082, 0.20547154049147015, + 0.1972613998245205, 0.21223025778509932, 0.22498680501540697, 0.22691830656416243, + 0.19521869814219142, 0.20988589092182444, 0.21868921978115488, 0.2240953476388927, + 0.20928467448290794, 0.20578906266338623, 0.20681378450968757, 0.22376613724792477, + 0.20553847889431168, 0.20376157669628492, 0.20862506432378858, 0.2195097946800901, + 0.21546464912779167, 0.2130691837611387, 0.22424979277256304}, new int[] {150, 3}, 'f'); INDArray expCUDA = Nd4j.create(new double[] {-0.1397797281402293, 0.262442968158004, 0.11257253487714672, - -0.14204931755613565, 0.26459585187861423, 0.11326958823058592, -0.14169128233548328, - 0.26498290860797713, 0.11363791196902154, -0.14232126667905204, 0.2653795480791151, - 0.11377287243047099, -0.13953189423305898, 0.2625621860805628, 0.11274888391033339, - -0.13726747097370906, 0.26043568605313927, 0.1110259706999667, -0.14119986101271959, - 0.26517763983630427, 0.11360221352588595, -0.14053290451163764, 0.2630865243565184, - 0.11278706577163364, -0.14309744661189566, 0.26650186027631995, 0.11428980254509002, - -0.14181125575709072, 0.2638849706413404, 0.11325345211563415, -0.13824683928327508, - 0.2602840431545141, 0.11167166360958086, -0.14102724341299233, 0.2638192134517527, - 0.11316217164895999, -0.14221044613799594, 0.2646000994613115, 0.11355782124995259, - -0.1428642360983056, 0.26665431757043373, 0.11454611162697295, -0.13493373555886776, - 0.25723700689792417, 0.11066871266027851, -0.13274473377543702, 0.25693570312125485, - 0.11004518408130244, -0.1365899988385908, 0.260775617522195, 0.11133859613971274, - -0.1397225928004509, 0.26290565902532126, 0.11243264263783195, -0.1371867315730828, - 0.2588103442915592, 0.11043327812855466, -0.13834625792794394, 0.2618474094769191, - 0.11221118251826752, -0.13991940132336245, 0.26117123507760176, 0.11167825524041167, - -0.13879578742895515, 0.2626615816962663, 0.11209734783562991, -0.13991412321056712, - 0.26461990802358687, 0.11378368217808009, -0.14082620714516011, 0.26400443437557636, - 0.111965718194097, -0.14128496857231843, 0.2635459447146433, 0.11298115125486892, - -0.1419966745979669, 0.26403632111095915, 0.11292424586380322, -0.14055553461452114, - 0.26384362761416763, 0.11243563386028677, -0.13968123293840568, 0.2619131683521957, - 0.11227050868947011, -0.14001305190002286, 0.2623219326079562, 0.11238822036193419, - -0.141911120074517, 0.26470575692604925, 0.11346951493365338, -0.1420437953574474, - 0.26455829651364116, 0.11331249801989905, -0.1395947537242465, 0.2622953617320538, - 0.11144093434955048, -0.13656997236245197, 0.2590949716288484, 0.11210367280536893, - -0.13481743979284383, 0.25776322971796567, 0.11122948174103235, -0.14181125575709072, - 0.2638849706413404, 0.11325345211563415, -0.14103682300076956, 0.2639123513628277, - 0.1130743968031554, -0.13876313113599886, 0.26072016513363033, 0.11165922212607637, - -0.14181125575709072, 0.2638849706413404, 0.11325345211563415, -0.14281547473615197, - 0.2664381301793583, 0.11428866752101949, -0.14032871539338249, 0.26266338471441153, - 0.11255798512378257, -0.13981966971765328, 0.26341655887464027, 0.11273795514065381, - -0.14388057567732068, 0.26714789047716925, 0.11440730710165808, -0.14223211956518317, - 0.2660736178596304, 0.11418899137369003, -0.14001004113201757, 0.2643822169708257, - 0.11201250285527188, -0.13883802483037633, 0.2619899769262556, 0.11175309451997711, - -0.1422205386545011, 0.2653028226880685, 0.11338102131495005, -0.13857253148598467, - 0.2612501894958287, 0.11229027994882086, -0.1419483670463606, 0.2652619372220056, - 0.11377674967870428, -0.13848229437537088, 0.2607602194371945, 0.11192438494521152, - -0.14083577467674055, 0.26346078628006814, 0.11290025765751623, 0.08820321741221084, - -0.042962937132769455, 0.036456111185867196, 0.08768912912215979, -0.04147520913561469, - 0.03800308958007328, 0.08849157340423255, -0.042869041687349965, 0.03640514758784335, - 0.08840222986126425, -0.03926208009247603, 0.04148233537457793, 0.08862602509961467, - -0.04179046555496778, 0.037843699525301824, 0.0885159045500426, -0.04029524056756361, - 0.04038791773259154, 0.0875052763451695, -0.041337546434876894, 0.03786887705583882, - 0.08788518291446613, -0.037679310772829086, 0.043742570040689446, 0.08883444543172667, - -0.04226276451085631, 0.037935789330510686, 0.08772309407654977, -0.038425579983097494, - 0.041939804213313996, 0.08808908456289374, -0.03799921404053968, 0.04358666800484748, - 0.08766472994380456, -0.04014660522142602, 0.03963355543195826, 0.0891685847336339, - -0.04080973046501342, 0.04077905573678634, 0.08859320520357397, -0.041209006169318566, - 0.03898071800741838, 0.08745416827005757, -0.03918013131062083, 0.04122845129298612, - 0.08802355573068858, -0.042103933343329215, 0.03752951602609446, 0.08789456036870592, - -0.039809397229546725, 0.040190830583142705, 0.08878158132891725, -0.04051137632979936, - 0.04096511133924192, 0.0889868438845658, -0.04098834442211815, 0.038992891303455214, - 0.08855010208484065, -0.03972297100409324, 0.0415308568061289, 0.0874194387267, - -0.04032259594136955, 0.03849601262835472, 0.08820726056619942, -0.04063536986928261, - 0.03972819093163967, 0.08915014368639582, -0.04165853399771313, 0.038287425905390665, - 0.08903747908029147, -0.041486958695521756, 0.03940023963411198, 0.08844748155900393, - -0.041555467710842814, 0.03868439107248724, 0.08823209789313106, -0.04191850288429148, - 0.03784066268725205, 0.089106470980532, -0.042740616548806856, 0.037094874798938124, - 0.08840698224388845, -0.04230890789862867, 0.03648221028869615, 0.08819168460920213, - -0.04065217650480662, 0.03919793487055319, 0.08832897727363223, -0.03968098276580225, - 0.0416667028390964, 0.08848272560584433, -0.03938587160340448, 0.04188955034091001, - 0.08854564025617767, -0.03942970016629068, 0.042104058952174755, 0.08830426865151225, - -0.040033880587860324, 0.04078181954110782, 0.08882030708521758, -0.04108360797322201, - 0.03864283772967878, 0.08781996871300206, -0.039376157124858285, 0.04070276372274433, - 0.08697060313120034, -0.040530214675006664, 0.038768867596498016, 0.08821150053622706, - -0.04227812405783864, 0.037096163362112966, 0.08920615348763587, -0.04143582234449487, - 0.03914792098507049, 0.08781612348104591, -0.03969271460836636, 0.040829736500267014, - 0.08829027306019399, -0.03930630213110146, 0.04138724809469049, 0.08863573847454138, - -0.039879877499865955, 0.04122260831236561, 0.0883335013507428, -0.041109045287316716, - 0.039008466274951054, 0.08850954251831919, -0.040127040514003495, 0.040758394091767146, - 0.08799737345705212, -0.0378824673311011, 0.04356830692232184, 0.08832089274747237, - -0.03976254339418301, 0.040901381865641434, 0.08812016738529858, -0.04014173593635116, - 0.040677302155068665, 0.08811124331057293, -0.039999358112224784, 0.04055527566668088, - 0.08838773492102094, -0.04114771748741527, 0.03920462961422201, 0.08757388372185691, - -0.037704526819131993, 0.04331206318950709, 0.0881576331615599, -0.03988942301339941, - 0.04067380373044536, 0.013495004596650092, -0.10467787904526984, -0.06924598498509513, - 0.013732488601800671, -0.10359958526920321, -0.073867622338269, 0.014663507273385447, - -0.10681226123870459, -0.06964113429219437, 0.014418259088360003, -0.10540559541359698, - -0.07329534366412284, 0.014081092360166813, -0.10538099101250491, -0.07055881966477318, - 0.015447314035613191, -0.10838784129437379, -0.06783586065961766, 0.013085578431350012, - -0.10107418630601421, -0.07612433531982664, 0.01553720555749748, -0.10797911451459238, - -0.07066651886657471, 0.014997053687435705, -0.10641485321579136, -0.07222340561309375, - 0.013865261741969313, -0.10650075751537919, -0.06693341363771006, 0.013766354176025222, - -0.10499674891965531, -0.07217795404205836, 0.014260160255118556, -0.10513678167003707, - -0.07264441501434046, 0.01420865307474977, -0.10584712083654362, -0.07062826312502943, - 0.013561444762186578, -0.10295250306644092, -0.07351315002254191, 0.013027918843870464, - -0.10261633218277293, -0.07130546452883366, 0.013426013289009173, -0.10454045824088537, - -0.0705867845954161, 0.014432368908178261, -0.10569362827120234, -0.07298426151600021, - 0.014858509648913935, -0.10801642563076536, -0.06707535623461379, 0.01563198862025337, - -0.10867604725646529, -0.06591468875118317, 0.014507371715046171, -0.10451467522071968, - -0.07532563977777654, 0.013994233557191597, -0.10592405632576582, -0.06912805117416723, - 0.013298523016463842, -0.10278349861729164, -0.0738409688404247, 0.015833152505383894, - -0.1088639069394899, -0.06806853577990854, 0.01407299710172178, -0.10468720551145808, - -0.07357408848277809, 0.0140921601711803, -0.1058209059211477, -0.07084032565658337, - 0.015094038913090283, -0.1073532833427305, -0.07120135240882869, 0.013890700619125151, - -0.10438742115148218, -0.07384287774382131, 0.013780213251619126, -0.10429428867892578, - -0.07404967280508117, 0.014131634326137085, -0.10510768374389001, -0.07140704509303743, - 0.015362427285654635, -0.10744618787519382, -0.07242981001774759, 0.01538546151471559, - -0.10786708970599292, -0.06990741917330204, 0.015041211700757331, -0.10805549163241167, - -0.06803553703700807, 0.013996256799870863, -0.10492288857316887, -0.07083972134954941, - 0.014547662409359825, -0.10531942691720375, -0.07503719765162918, 0.014926121528476222, - -0.10557934559199808, -0.07556161565121688, 0.014876220620969672, -0.10779446920555454, - -0.06663943676230895, 0.013299175512052633, -0.1044884887849407, -0.07012365270229738, - 0.014310962748405539, -0.10548746091961465, -0.07322203418066323, 0.013650673557163585, - -0.10398724057331998, -0.07427001885442606, 0.014138340887381534, -0.10592565377607648, - -0.07048866647966796, 0.013731535938664729, -0.10526565567088782, -0.0690572199450989, - 0.013648640373434837, -0.10533812001977037, -0.0694939741135303, 0.013732488601800671, - -0.10359958526920321, -0.073867622338269, 0.01407159219681424, -0.10593041742703585, - -0.06924501967896152, 0.013525129270068457, -0.10521593889975128, -0.06845021944709595, - 0.013666043658741503, -0.10503231289490245, -0.06987960468127483, 0.01409981353709936, - -0.1045716285237493, -0.07292719015185552, 0.013960146652089741, -0.10510748952535, - -0.0720500850059039, 0.01324381306751248, -0.10425347510224883, -0.07104713730722463, - 0.013781373376598662, -0.10407691618897837, -0.07430592437883553}, new int[] {150, 3}, 'c'); + -0.14204931755613565, 0.26459585187861423, 0.11326958823058592, -0.14169128233548328, + 0.26498290860797713, 0.11363791196902154, -0.14232126667905204, 0.2653795480791151, + 0.11377287243047099, -0.13953189423305898, 0.2625621860805628, 0.11274888391033339, + -0.13726747097370906, 0.26043568605313927, 0.1110259706999667, -0.14119986101271959, + 0.26517763983630427, 0.11360221352588595, -0.14053290451163764, 0.2630865243565184, + 0.11278706577163364, -0.14309744661189566, 0.26650186027631995, 0.11428980254509002, + -0.14181125575709072, 0.2638849706413404, 0.11325345211563415, -0.13824683928327508, + 0.2602840431545141, 0.11167166360958086, -0.14102724341299233, 0.2638192134517527, + 0.11316217164895999, -0.14221044613799594, 0.2646000994613115, 0.11355782124995259, + -0.1428642360983056, 0.26665431757043373, 0.11454611162697295, -0.13493373555886776, + 0.25723700689792417, 0.11066871266027851, -0.13274473377543702, 0.25693570312125485, + 0.11004518408130244, -0.1365899988385908, 0.260775617522195, 0.11133859613971274, + -0.1397225928004509, 0.26290565902532126, 0.11243264263783195, -0.1371867315730828, + 0.2588103442915592, 0.11043327812855466, -0.13834625792794394, 0.2618474094769191, + 0.11221118251826752, -0.13991940132336245, 0.26117123507760176, 0.11167825524041167, + -0.13879578742895515, 0.2626615816962663, 0.11209734783562991, -0.13991412321056712, + 0.26461990802358687, 0.11378368217808009, -0.14082620714516011, 0.26400443437557636, + 0.111965718194097, -0.14128496857231843, 0.2635459447146433, 0.11298115125486892, + -0.1419966745979669, 0.26403632111095915, 0.11292424586380322, -0.14055553461452114, + 0.26384362761416763, 0.11243563386028677, -0.13968123293840568, 0.2619131683521957, + 0.11227050868947011, -0.14001305190002286, 0.2623219326079562, 0.11238822036193419, + -0.141911120074517, 0.26470575692604925, 0.11346951493365338, -0.1420437953574474, + 0.26455829651364116, 0.11331249801989905, -0.1395947537242465, 0.2622953617320538, + 0.11144093434955048, -0.13656997236245197, 0.2590949716288484, 0.11210367280536893, + -0.13481743979284383, 0.25776322971796567, 0.11122948174103235, -0.14181125575709072, + 0.2638849706413404, 0.11325345211563415, -0.14103682300076956, 0.2639123513628277, + 0.1130743968031554, -0.13876313113599886, 0.26072016513363033, 0.11165922212607637, + -0.14181125575709072, 0.2638849706413404, 0.11325345211563415, -0.14281547473615197, + 0.2664381301793583, 0.11428866752101949, -0.14032871539338249, 0.26266338471441153, + 0.11255798512378257, -0.13981966971765328, 0.26341655887464027, 0.11273795514065381, + -0.14388057567732068, 0.26714789047716925, 0.11440730710165808, -0.14223211956518317, + 0.2660736178596304, 0.11418899137369003, -0.14001004113201757, 0.2643822169708257, + 0.11201250285527188, -0.13883802483037633, 0.2619899769262556, 0.11175309451997711, + -0.1422205386545011, 0.2653028226880685, 0.11338102131495005, -0.13857253148598467, + 0.2612501894958287, 0.11229027994882086, -0.1419483670463606, 0.2652619372220056, + 0.11377674967870428, -0.13848229437537088, 0.2607602194371945, 0.11192438494521152, + -0.14083577467674055, 0.26346078628006814, 0.11290025765751623, 0.08820321741221084, + -0.042962937132769455, 0.036456111185867196, 0.08768912912215979, -0.04147520913561469, + 0.03800308958007328, 0.08849157340423255, -0.042869041687349965, 0.03640514758784335, + 0.08840222986126425, -0.03926208009247603, 0.04148233537457793, 0.08862602509961467, + -0.04179046555496778, 0.037843699525301824, 0.0885159045500426, -0.04029524056756361, + 0.04038791773259154, 0.0875052763451695, -0.041337546434876894, 0.03786887705583882, + 0.08788518291446613, -0.037679310772829086, 0.043742570040689446, 0.08883444543172667, + -0.04226276451085631, 0.037935789330510686, 0.08772309407654977, -0.038425579983097494, + 0.041939804213313996, 0.08808908456289374, -0.03799921404053968, 0.04358666800484748, + 0.08766472994380456, -0.04014660522142602, 0.03963355543195826, 0.0891685847336339, + -0.04080973046501342, 0.04077905573678634, 0.08859320520357397, -0.041209006169318566, + 0.03898071800741838, 0.08745416827005757, -0.03918013131062083, 0.04122845129298612, + 0.08802355573068858, -0.042103933343329215, 0.03752951602609446, 0.08789456036870592, + -0.039809397229546725, 0.040190830583142705, 0.08878158132891725, -0.04051137632979936, + 0.04096511133924192, 0.0889868438845658, -0.04098834442211815, 0.038992891303455214, + 0.08855010208484065, -0.03972297100409324, 0.0415308568061289, 0.0874194387267, + -0.04032259594136955, 0.03849601262835472, 0.08820726056619942, -0.04063536986928261, + 0.03972819093163967, 0.08915014368639582, -0.04165853399771313, 0.038287425905390665, + 0.08903747908029147, -0.041486958695521756, 0.03940023963411198, 0.08844748155900393, + -0.041555467710842814, 0.03868439107248724, 0.08823209789313106, -0.04191850288429148, + 0.03784066268725205, 0.089106470980532, -0.042740616548806856, 0.037094874798938124, + 0.08840698224388845, -0.04230890789862867, 0.03648221028869615, 0.08819168460920213, + -0.04065217650480662, 0.03919793487055319, 0.08832897727363223, -0.03968098276580225, + 0.0416667028390964, 0.08848272560584433, -0.03938587160340448, 0.04188955034091001, + 0.08854564025617767, -0.03942970016629068, 0.042104058952174755, 0.08830426865151225, + -0.040033880587860324, 0.04078181954110782, 0.08882030708521758, -0.04108360797322201, + 0.03864283772967878, 0.08781996871300206, -0.039376157124858285, 0.04070276372274433, + 0.08697060313120034, -0.040530214675006664, 0.038768867596498016, 0.08821150053622706, + -0.04227812405783864, 0.037096163362112966, 0.08920615348763587, -0.04143582234449487, + 0.03914792098507049, 0.08781612348104591, -0.03969271460836636, 0.040829736500267014, + 0.08829027306019399, -0.03930630213110146, 0.04138724809469049, 0.08863573847454138, + -0.039879877499865955, 0.04122260831236561, 0.0883335013507428, -0.041109045287316716, + 0.039008466274951054, 0.08850954251831919, -0.040127040514003495, 0.040758394091767146, + 0.08799737345705212, -0.0378824673311011, 0.04356830692232184, 0.08832089274747237, + -0.03976254339418301, 0.040901381865641434, 0.08812016738529858, -0.04014173593635116, + 0.040677302155068665, 0.08811124331057293, -0.039999358112224784, 0.04055527566668088, + 0.08838773492102094, -0.04114771748741527, 0.03920462961422201, 0.08757388372185691, + -0.037704526819131993, 0.04331206318950709, 0.0881576331615599, -0.03988942301339941, + 0.04067380373044536, 0.013495004596650092, -0.10467787904526984, -0.06924598498509513, + 0.013732488601800671, -0.10359958526920321, -0.073867622338269, 0.014663507273385447, + -0.10681226123870459, -0.06964113429219437, 0.014418259088360003, -0.10540559541359698, + -0.07329534366412284, 0.014081092360166813, -0.10538099101250491, -0.07055881966477318, + 0.015447314035613191, -0.10838784129437379, -0.06783586065961766, 0.013085578431350012, + -0.10107418630601421, -0.07612433531982664, 0.01553720555749748, -0.10797911451459238, + -0.07066651886657471, 0.014997053687435705, -0.10641485321579136, -0.07222340561309375, + 0.013865261741969313, -0.10650075751537919, -0.06693341363771006, 0.013766354176025222, + -0.10499674891965531, -0.07217795404205836, 0.014260160255118556, -0.10513678167003707, + -0.07264441501434046, 0.01420865307474977, -0.10584712083654362, -0.07062826312502943, + 0.013561444762186578, -0.10295250306644092, -0.07351315002254191, 0.013027918843870464, + -0.10261633218277293, -0.07130546452883366, 0.013426013289009173, -0.10454045824088537, + -0.0705867845954161, 0.014432368908178261, -0.10569362827120234, -0.07298426151600021, + 0.014858509648913935, -0.10801642563076536, -0.06707535623461379, 0.01563198862025337, + -0.10867604725646529, -0.06591468875118317, 0.014507371715046171, -0.10451467522071968, + -0.07532563977777654, 0.013994233557191597, -0.10592405632576582, -0.06912805117416723, + 0.013298523016463842, -0.10278349861729164, -0.0738409688404247, 0.015833152505383894, + -0.1088639069394899, -0.06806853577990854, 0.01407299710172178, -0.10468720551145808, + -0.07357408848277809, 0.0140921601711803, -0.1058209059211477, -0.07084032565658337, + 0.015094038913090283, -0.1073532833427305, -0.07120135240882869, 0.013890700619125151, + -0.10438742115148218, -0.07384287774382131, 0.013780213251619126, -0.10429428867892578, + -0.07404967280508117, 0.014131634326137085, -0.10510768374389001, -0.07140704509303743, + 0.015362427285654635, -0.10744618787519382, -0.07242981001774759, 0.01538546151471559, + -0.10786708970599292, -0.06990741917330204, 0.015041211700757331, -0.10805549163241167, + -0.06803553703700807, 0.013996256799870863, -0.10492288857316887, -0.07083972134954941, + 0.014547662409359825, -0.10531942691720375, -0.07503719765162918, 0.014926121528476222, + -0.10557934559199808, -0.07556161565121688, 0.014876220620969672, -0.10779446920555454, + -0.06663943676230895, 0.013299175512052633, -0.1044884887849407, -0.07012365270229738, + 0.014310962748405539, -0.10548746091961465, -0.07322203418066323, 0.013650673557163585, + -0.10398724057331998, -0.07427001885442606, 0.014138340887381534, -0.10592565377607648, + -0.07048866647966796, 0.013731535938664729, -0.10526565567088782, -0.0690572199450989, + 0.013648640373434837, -0.10533812001977037, -0.0694939741135303, 0.013732488601800671, + -0.10359958526920321, -0.073867622338269, 0.01407159219681424, -0.10593041742703585, + -0.06924501967896152, 0.013525129270068457, -0.10521593889975128, -0.06845021944709595, + 0.013666043658741503, -0.10503231289490245, -0.06987960468127483, 0.01409981353709936, + -0.1045716285237493, -0.07292719015185552, 0.013960146652089741, -0.10510748952535, + -0.0720500850059039, 0.01324381306751248, -0.10425347510224883, -0.07104713730722463, + 0.013781373376598662, -0.10407691618897837, -0.07430592437883553}, new int[] {150, 3}, 'c'); INDArray res = x.muli(y); @@ -1255,7 +1323,7 @@ public class RandomTests extends BaseNd4jTest { @Test @Disabled - public void testTruncatedNormal1() { + public void testTruncatedNormal1(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); INDArray z01 = Nd4j.create(10000000).assign(-119119d); @@ -1281,7 +1349,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testLogNormal1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLogNormal1(Nd4jBackend backend) { Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); INDArray z01 = Nd4j.create(1000000); @@ -1307,7 +1377,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testLinspace2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLinspace2(Nd4jBackend backend) { INDArray res = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray exp = Nd4j.create(new double[] {1, 2, 3, 4, 5}); @@ -1316,24 +1388,32 @@ public class RandomTests extends BaseNd4jTest { @Test - public void testOrthogonalDistribution1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOrthogonalDistribution1(Nd4jBackend backend) { val dist = new OrthogonalDistribution(1.0); val array = dist.sample(new int[] {6, 9}); } @Test - public void testOrthogonalDistribution2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOrthogonalDistribution2(Nd4jBackend backend) { val dist = new OrthogonalDistribution(1.0); val array = dist.sample(new int[] {9, 6}); } @Test - public void testOrthogonalDistribution3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOrthogonalDistribution3(Nd4jBackend backend) { val dist = new OrthogonalDistribution(1.0); val array = dist.sample(new int[] {9, 9}); } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void reproducabilityTest(){ int numBatches = 1; @@ -1350,7 +1430,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testJavaInt_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testJavaInt_1(Nd4jBackend backend) { for (int e = 0; e < 100000; e++) { val i = Nd4j.getRandom().nextInt(10, 20); @@ -1359,6 +1441,8 @@ public class RandomTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBernoulli(){ Nd4j.getRandom().setSeed(12345); INDArray arr = Nd4j.create(DataType.DOUBLE, 100); @@ -1380,6 +1464,8 @@ public class RandomTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRngRepeatabilityUniform(){ val nexp = Nd4j.create(DataType.FLOAT, 10); Nd4j.getRandom().setSeed(12345); @@ -1395,6 +1481,8 @@ public class RandomTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRngRepeatabilityBernoulli(){ Nd4j.getRandom().setSeed(12345); INDArray out1 = Nd4j.create(DataType.FLOAT, 10); @@ -1408,6 +1496,8 @@ public class RandomTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testGamma(){ Nd4j.getRandom().setSeed(12345); INDArray shape = Nd4j.createFromArray(new int[] {1000,1000}); @@ -1429,6 +1519,8 @@ public class RandomTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testPoisson(){ Nd4j.getRandom().setSeed(12345); INDArray shape = Nd4j.createFromArray(new int[] {1,3}); @@ -1442,6 +1534,8 @@ public class RandomTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testShuffle(){ Nd4j.getRandom().setSeed(12345); INDArray alpha = Nd4j.rand(1,3); @@ -1454,7 +1548,9 @@ public class RandomTests extends BaseNd4jTest { } @Test - public void testRandom() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRandom(Nd4jBackend backend) { val r1 = new java.util.Random(119); val r2 = Nd4j.getRandom(); r2.setSeed(119); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java index 5f74d8be4..715548c36 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java @@ -27,10 +27,11 @@ import lombok.Builder; import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; -import org.nd4j.OpValidationSuite; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.common.base.Preconditions; import org.nd4j.common.util.ArrayUtil; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; @@ -63,11 +64,8 @@ import java.util.List; import java.util.Map; @Slf4j -public class RngValidationTests extends BaseNd4jTest { +public class RngValidationTests extends BaseNd4jTestWithBackends { - public RngValidationTests(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -124,9 +122,9 @@ public class RngValidationTests extends BaseNd4jTest { @Test - public void validateRngDistributions(){ - OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6958 - 2018-01-09 - + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void validateRngDistributions(Nd4jBackend backend){ List testCases = new ArrayList<>(); for(DataType type : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { //Legacy (non-custom) RNG ops: @@ -154,8 +152,8 @@ public class RngValidationTests extends BaseNd4jTest { testCases.add(TestCase.builder().opType("binomial").dataType(type).shape(100,10000).minValue(0).maxValue(20).minValueInclusive(true).maxValueInclusive(true).arg("n", 20).arg("p",0.2) .expectedMean(20*0.2).expectedStd(Math.sqrt(20*0.2*(1-0.2)) /*var = np(1-p)*/).meanRelativeErrorTolerance(0.001).stdRelativeErrorTolerance(0.01).build()); - //truncated normal clips at (mean-2*std, mean+2*std). Mean for equal 2 sided clipping about mean is same as original mean. Variance is difficult to calculate... - //Assume variance is similar to non-truncated normal (should be a bit less in practice) but use large relative error here + //truncated normal clips at (mean-2*std, mean+2*std). Mean for equal 2 sided clipping about mean is same as original mean. Variance is difficult to calculate... + //Assume variance is similar to non-truncated normal (should be a bit less in practice) but use large relative error here testCases.add(TestCase.builder().opType("truncated_normal").dataType(type).shape(new long[0]).minValue(-2.0).maxValue(2.0).minValueInclusive(true).maxValueInclusive(true).arg("mean", 0.0).arg("std", 1.0).build()); //Don't check mean/std for 1 element testCases.add(TestCase.builder().opType("truncated_normal").dataType(type).shape(1000).minValue(-2.0).maxValue(2.0).minValueInclusive(true).maxValueInclusive(true).arg("mean", 0.0).arg("std", 1.0) .expectedMean(0.0).expectedStd(1.0).stdRelativeErrorTolerance(0.2).meanMinAbsErrorTolerance(0.1).build()); @@ -350,16 +348,16 @@ public class RngValidationTests extends BaseNd4jTest { } private static double minValue(DataType dataType){ - switch (dataType){ - case DOUBLE: - return -Double.MAX_VALUE; - case FLOAT: - return -Float.MAX_VALUE; - case HALF: - return -65504.0; - default: - throw new RuntimeException("Dtype not supported: " + dataType); - } + switch (dataType){ + case DOUBLE: + return -Double.MAX_VALUE; + case FLOAT: + return -Float.MAX_VALUE; + case HALF: + return -65504.0; + default: + throw new RuntimeException("Dtype not supported: " + dataType); + } } private static double maxValue(DataType dataType){ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/schedule/TestSchedules.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/schedule/TestSchedules.java index 761380274..38d086d4f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/schedule/TestSchedules.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/schedule/TestSchedules.java @@ -21,7 +21,9 @@ package org.nd4j.linalg.schedule; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.shade.jackson.databind.DeserializationFeature; import org.nd4j.shade.jackson.databind.MapperFeature; @@ -30,18 +32,17 @@ import org.nd4j.shade.jackson.databind.SerializationFeature; import static org.junit.jupiter.api.Assertions.assertEquals; -public class TestSchedules extends BaseNd4jTest { +public class TestSchedules extends BaseNd4jTestWithBackends { - public TestSchedules(Nd4jBackend b){ - super(b); - } @Override - public char ordering(){ + public char ordering() { return 'c'; } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testJson() throws Exception { ObjectMapper om = new ObjectMapper(); @@ -69,7 +70,9 @@ public class TestSchedules extends BaseNd4jTest { } @Test - public void testScheduleValues(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScheduleValues(Nd4jBackend backend) { double lr = 0.8; double decay = 0.9; @@ -120,7 +123,9 @@ public class TestSchedules extends BaseNd4jTest { } @Test - public void testMapSchedule(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMapSchedule(Nd4jBackend backend) { ISchedule schedule = new MapSchedule.Builder(ScheduleType.ITERATION) .add(0, 0.5) @@ -136,7 +141,9 @@ public class TestSchedules extends BaseNd4jTest { } } @Test - public void testCycleSchedule(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCycleSchedule(Nd4jBackend backend) { ISchedule schedule = new CycleSchedule(ScheduleType.ITERATION, 1.5, 100); assertEquals(0.15, schedule.valueAt(0, 0), 1e-6); assertEquals(1.5, schedule.valueAt(45, 0), 1e-6); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/BasicSerDeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/BasicSerDeTests.java index cdbb14bd3..fa4abbd1b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/BasicSerDeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/BasicSerDeTests.java @@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -37,15 +38,11 @@ import java.io.ByteArrayOutputStream; import static junit.framework.TestCase.assertEquals; -@RunWith(Parameterized.class) -@Slf4j -public class BasicSerDeTests extends BaseNd4jTest { - public BasicSerDeTests(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } - DataType initialType; +@Slf4j +public class BasicSerDeTests extends BaseNd4jTestWithBackends { + + DataType initialType = Nd4j.dataType(); @AfterEach public void after() { @@ -54,7 +51,9 @@ public class BasicSerDeTests extends BaseNd4jTest { @Test - public void testBasicDataTypeSwitch1() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicDataTypeSwitch1(Nd4jBackend backend) throws Exception { DataType initialType = Nd4j.dataType(); Nd4j.setDataType(DataType.FLOAT); @@ -82,7 +81,9 @@ public class BasicSerDeTests extends BaseNd4jTest { } @Test - public void testHalfSerde_1() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHalfSerde_1(Nd4jBackend backend) throws Exception { val array = Nd4j.create(DataType.HALF, 3, 4); array.assign(1.0f); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/JsonSerdeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/JsonSerdeTests.java index d2e70277c..dfaf5bfe7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/JsonSerdeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/JsonSerdeTests.java @@ -25,7 +25,9 @@ import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import lombok.val; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -41,11 +43,8 @@ import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; import static org.junit.jupiter.api.Assertions.assertEquals; -public class JsonSerdeTests extends BaseNd4jTest { +public class JsonSerdeTests extends BaseNd4jTestWithBackends { - public JsonSerdeTests(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -54,7 +53,9 @@ public class JsonSerdeTests extends BaseNd4jTest { @Test - public void testNDArrayTextSerializer() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNDArrayTextSerializer(Nd4jBackend backend) throws Exception { for(char order : new char[]{'c', 'f'}) { Nd4j.factory().setOrder(order); for (DataType globalDT : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { @@ -91,7 +92,9 @@ public class JsonSerdeTests extends BaseNd4jTest { @Test - public void testBackwardCompatability() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBackwardCompatability(Nd4jBackend backend) throws Exception { Nd4j.getNDArrayFactory().setOrder('f'); for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java index 706c727f5..63fe8057a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/LargeSerDeTests.java @@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -36,16 +37,15 @@ import java.io.*; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) + @Slf4j @Disabled("AB 2019/05/23 - JVM crash on linux-x86_64-cpu-avx512 - issue #7657") -public class LargeSerDeTests extends BaseNd4jTest { - public LargeSerDeTests(Nd4jBackend backend) { - super(backend); - } +public class LargeSerDeTests extends BaseNd4jTestWithBackends { - @Test - public void testLargeArraySerDe_1() throws Exception { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLargeArraySerDe_1(Nd4jBackend backend) throws Exception { val arrayA = Nd4j.rand(new long[] {1, 135079944}); //val arrayA = Nd4j.rand(new long[] {1, 13507}); @@ -69,7 +69,7 @@ public class LargeSerDeTests extends BaseNd4jTest { @Test @Disabled // this should be commented out, since it requires approx 10GB ram to run - public void testLargeArraySerDe_2() throws Exception { + public void testLargeArraySerDe_2(Nd4jBackend backend) throws Exception { INDArray arrayA = Nd4j.createUninitialized(100000, 12500); log.info("Shape: {}; Length: {}", arrayA.shape(), arrayA.length()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java index f8452e84a..9f8d5d7af 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java @@ -28,7 +28,9 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -42,15 +44,12 @@ import java.util.Map; import static org.junit.jupiter.api.Assertions.*; @Slf4j -public class NumpyFormatTests extends BaseNd4jTest { - - - public NumpyFormatTests(Nd4jBackend backend) { - super(backend); - } +public class NumpyFormatTests extends BaseNd4jTestWithBackends { @Test - public void testToNpyFormat(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToNpyFormat(@TempDir Path testDir,Nd4jBackend backend) throws Exception { val dir = testDir.toFile(); new ClassPathResource("numpy_arrays/").copyDirectory(dir); @@ -99,7 +98,9 @@ public class NumpyFormatTests extends BaseNd4jTest { } @Test - public void testToNpyFormatScalars(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToNpyFormatScalars(@TempDir Path testDir,Nd4jBackend backend) throws Exception { // File dir = new File("C:\\DL4J\\Git\\dl4j-test-resources\\src\\main\\resources\\numpy_arrays\\scalar"); val dir = testDir.toFile(); @@ -153,7 +154,9 @@ public class NumpyFormatTests extends BaseNd4jTest { @Test - public void testNpzReading(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNpzReading(@TempDir Path testDir,Nd4jBackend backend) throws Exception { val dir = testDir.toFile(); new ClassPathResource("numpy_arrays/npz/").copyDirectory(dir); @@ -193,7 +196,9 @@ public class NumpyFormatTests extends BaseNd4jTest { } @Test - public void testTxtReading() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTxtReading(Nd4jBackend backend) throws Exception { File f = new ClassPathResource("numpy_arrays/txt/arange_3,4_float32.txt").getFile(); INDArray arr = Nd4j.readNumpy(DataType.FLOAT, f.getPath()); @@ -212,7 +217,9 @@ public class NumpyFormatTests extends BaseNd4jTest { @Test - public void testNpy(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNpy(@TempDir Path testDir,Nd4jBackend backend) throws Exception { for(boolean empty : new boolean[]{false, true}) { val dir = testDir.toFile(); if(!empty) { @@ -256,13 +263,15 @@ public class NumpyFormatTests extends BaseNd4jTest { } @Test - public void testFromNumpyScalar() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFromNumpyScalar(Nd4jBackend backend) throws Exception { val out = Nd4j.createFromNpyFile(new ClassPathResource("numpy_oneoff/scalar.npy").getFile()); assertEquals(Nd4j.scalar(DataType.INT, 1), out); } @Test() - public void readNumpyCorruptHeader1(@TempDir Path testDir) throws Exception { + public void readNumpyCorruptHeader1(@TempDir Path testDir,Nd4jBackend backend) throws Exception { assertThrows(RuntimeException.class,() -> { File f = testDir.toFile(); @@ -286,7 +295,7 @@ public class NumpyFormatTests extends BaseNd4jTest { } @Test() - public void readNumpyCorruptHeader2(@TempDir Path testDir) throws Exception { + public void readNumpyCorruptHeader2(@TempDir Path testDir,Nd4jBackend backend) throws Exception { assertThrows(RuntimeException.class,() -> { File f = testDir.toFile(); @@ -310,7 +319,7 @@ public class NumpyFormatTests extends BaseNd4jTest { } @Test() - public void testAbsentNumpyFile_1() throws Exception { + public void testAbsentNumpyFile_1(Nd4jBackend backend) throws Exception { assertThrows(IllegalArgumentException.class,() -> { val f = new File("pew-pew-zomg.some_extension_that_wont_exist"); INDArray act1 = Nd4j.createFromNpyFile(f); @@ -319,7 +328,7 @@ public class NumpyFormatTests extends BaseNd4jTest { } @Test() - public void testAbsentNumpyFile_2() throws Exception { + public void testAbsentNumpyFile_2(Nd4jBackend backend) throws Exception { assertThrows(IllegalArgumentException.class,() -> { val f = new File("c:/develop/batch-x-1.npy"); INDArray act1 = Nd4j.createFromNpyFile(f); @@ -330,7 +339,9 @@ public class NumpyFormatTests extends BaseNd4jTest { @Disabled @Test - public void testNumpyBoolean() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNumpyBoolean(Nd4jBackend backend) { INDArray out = Nd4j.createFromNpyFile(new File("c:/Users/raver/Downloads/error2.npy")); // System.out.println(ArrayUtil.toList(ArrayUtil.toInts(out.shape()))); // System.out.println(out); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java index 14b0e858d..d51dbc2a7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.shape; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -37,19 +38,16 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.*; @Slf4j -@RunWith(Parameterized.class) -public class EmptyTests extends BaseNd4jTest { - DataType initialType; +public class EmptyTests extends BaseNd4jTestWithBackends { - public EmptyTests(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } + DataType initialType = Nd4j.dataType(); @Test - public void testEmpyArray_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmpyArray_1(Nd4jBackend backend) { val array = Nd4j.empty(); assertNotNull(array); @@ -69,7 +67,9 @@ public class EmptyTests extends BaseNd4jTest { @Test - public void testEmptyDtype_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyDtype_1(Nd4jBackend backend) { val array = Nd4j.empty(DataType.INT); assertTrue(array.isEmpty()); @@ -77,7 +77,9 @@ public class EmptyTests extends BaseNd4jTest { } @Test - public void testEmptyDtype_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyDtype_2(Nd4jBackend backend) { val array = Nd4j.empty(DataType.LONG); assertTrue(array.isEmpty()); @@ -85,7 +87,9 @@ public class EmptyTests extends BaseNd4jTest { } @Test - public void testConcat_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcat_1(Nd4jBackend backend) { val row1 = Nd4j.create(new double[]{1, 1, 1, 1}, new long[]{1, 4}); val row2 = Nd4j.create(new double[]{2, 2, 2, 2}, new long[]{1, 4}); val row3 = Nd4j.create(new double[]{3, 3, 3, 3}, new long[]{1, 4}); @@ -105,7 +109,9 @@ public class EmptyTests extends BaseNd4jTest { } @Test - public void testEmptyReductions(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyReductions(Nd4jBackend backend){ INDArray empty = Nd4j.empty(DataType.FLOAT); try { @@ -134,7 +140,9 @@ public class EmptyTests extends BaseNd4jTest { } @Test - public void testGetEmpty(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetEmpty(Nd4jBackend backend){ INDArray empty = Nd4j.empty(DataType.FLOAT); try { empty.getFloat(0); @@ -156,7 +164,9 @@ public class EmptyTests extends BaseNd4jTest { } @Test - public void testEmptyWithShape_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyWithShape_1(Nd4jBackend backend) { val array = Nd4j.create(DataType.FLOAT, 2, 0, 3); assertNotNull(array); @@ -168,7 +178,9 @@ public class EmptyTests extends BaseNd4jTest { } @Test - public void testEmptyWithShape_2(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyWithShape_2(Nd4jBackend backend){ val array = Nd4j.create(DataType.FLOAT, 0); assertNotNull(array); @@ -181,7 +193,10 @@ public class EmptyTests extends BaseNd4jTest { } @Test() - public void testEmptyWithShape_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + + public void testEmptyWithShape_3(Nd4jBackend backend) { assertThrows(IllegalArgumentException.class,() -> { val array = Nd4j.create(DataType.FLOAT, 2, 0, 3); array.tensorAlongDimension(0, 2); @@ -190,7 +205,10 @@ public class EmptyTests extends BaseNd4jTest { } @Test - public void testEmptyWithShape_4(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + + public void testEmptyWithShape_4(Nd4jBackend backend){ val array = Nd4j.create(DataType.FLOAT, 0, 3); assertNotNull(array); @@ -209,7 +227,9 @@ public class EmptyTests extends BaseNd4jTest { } @Test - public void testEmptyReduction_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyReduction_1(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 2, 0, 3); val e = Nd4j.create(DataType.FLOAT, 2, 1, 3).assign(0); @@ -220,7 +240,9 @@ public class EmptyTests extends BaseNd4jTest { } @Test - public void testEmptyReduction_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyReduction_2(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 2, 0, 3); val e = Nd4j.create(DataType.FLOAT, 2, 3).assign(0); @@ -232,7 +254,10 @@ public class EmptyTests extends BaseNd4jTest { @Test - public void testEmptyReduction_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + + public void testEmptyReduction_3(Nd4jBackend backend) { val x = Nd4j.create(DataType.FLOAT, 2, 0); val e = Nd4j.create(DataType.FLOAT, 0); @@ -243,21 +268,25 @@ public class EmptyTests extends BaseNd4jTest { } @Test() - public void testEmptyReduction_4() { - assertThrows(ND4JIllegalStateException.class,() -> { - val x = Nd4j.create(DataType.FLOAT, 2, 0); - val e = Nd4j.create(DataType.FLOAT, 0); + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyReduction_4(Nd4jBackend backend) { + assertThrows(ND4JIllegalStateException.class,() -> { + val x = Nd4j.create(DataType.FLOAT, 2, 0); + val e = Nd4j.create(DataType.FLOAT, 0); - val reduced = x.argMax(1); + val reduced = x.argMax(1); - assertArrayEquals(e.shape(), reduced.shape()); - assertEquals(e, reduced); - }); + assertArrayEquals(e.shape(), reduced.shape()); + assertEquals(e, reduced); + }); } @Test - public void testEmptyCreateMethods(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyCreateMethods(Nd4jBackend backend){ DataType dt = DataType.FLOAT; assertArrayEquals(new long[]{0}, Nd4j.create(0).shape()); assertArrayEquals(new long[]{0,0}, Nd4j.create(0,0).shape()); @@ -297,13 +326,18 @@ public class EmptyTests extends BaseNd4jTest { } @Test - public void testEqualShapesEmpty(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + + public void testEqualShapesEmpty(Nd4jBackend backend){ assertTrue(Nd4j.create(0).equalShapes(Nd4j.create(0))); assertFalse(Nd4j.create(0).equalShapes(Nd4j.create(1, 0))); } @Test - public void testEmptyWhere() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEmptyWhere(Nd4jBackend backend) { val mask = Nd4j.createFromArray(false, false, false, false, false); val result = Nd4j.where(mask, null, null); @@ -312,7 +346,9 @@ public class EmptyTests extends BaseNd4jTest { } @Test - public void testAllEmptyReduce(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllEmptyReduce(Nd4jBackend backend){ INDArray x = Nd4j.createFromArray(true, true, true); val all = new All(x); all.setEmptyReduce(true); //For TF compatibility - empty array for axis (which means no-op - and NOT all array reduction) @@ -321,7 +357,10 @@ public class EmptyTests extends BaseNd4jTest { } @Test - public void testEmptyNoop() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + + public void testEmptyNoop(Nd4jBackend backend) { val output = Nd4j.empty(DataType.LONG); val op = DynamicCustomOp.builder("noop") @@ -332,7 +371,10 @@ public class EmptyTests extends BaseNd4jTest { } @Test - public void testEmptyConstructor_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + + public void testEmptyConstructor_1(Nd4jBackend backend) { val x = Nd4j.create(new double[0]); assertTrue(x.isEmpty()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java index 2db07226d..d5c650044 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/LongShapeTests.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.shape; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -32,16 +33,15 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -@RunWith(Parameterized.class) -public class LongShapeTests extends BaseNd4jTest { - public LongShapeTests(Nd4jBackend backend) { - super(backend); - } +public class LongShapeTests extends BaseNd4jTestWithBackends { + @Test - public void testLongBuffer_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLongBuffer_1(Nd4jBackend backend) { val exp = new long[]{2, 5, 3, 3, 1, 0, 1, 99}; val buffer = Nd4j.getDataBufferFactory().createLong(exp); @@ -52,7 +52,9 @@ public class LongShapeTests extends BaseNd4jTest { @Test - public void testLongShape_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLongShape_1(Nd4jBackend backend) { val exp = new long[]{2, 5, 3, 3, 1, 16384, 1, 99}; val array = Nd4j.createUninitialized(DataType.DOUBLE, 5, 3); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/NDArrayMathTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/NDArrayMathTests.java index c7ba3e7b0..37f37f15b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/NDArrayMathTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/NDArrayMathTests.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.shape; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -37,16 +38,15 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class NDArrayMathTests extends BaseNd4jTest { - public NDArrayMathTests(Nd4jBackend backend) { - super(backend); - } +public class NDArrayMathTests extends BaseNd4jTestWithBackends { + @Test - public void testVectorPerSlice() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorPerSlice(Nd4jBackend backend) { INDArray arr = Nd4j.create(2, 2, 2, 2); assertEquals(4, NDArrayMath.vectorsPerSlice(arr)); @@ -59,20 +59,26 @@ public class NDArrayMathTests extends BaseNd4jTest { } @Test - public void testMatricesPerSlice() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatricesPerSlice(Nd4jBackend backend) { INDArray arr = Nd4j.create(2, 2, 2, 2); assertEquals(2, NDArrayMath.matricesPerSlice(arr)); } @Test - public void testLengthPerSlice() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLengthPerSlice(Nd4jBackend backend) { INDArray arr = Nd4j.create(2, 2, 2, 2); val lengthPerSlice = NDArrayMath.lengthPerSlice(arr); assertEquals(8, lengthPerSlice); } @Test - public void toffsetForSlice() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void toffsetForSlice(Nd4jBackend backend) { INDArray arr = Nd4j.create(3, 2, 2); int slice = 1; assertEquals(4, NDArrayMath.offsetForSlice(arr, slice)); @@ -80,13 +86,17 @@ public class NDArrayMathTests extends BaseNd4jTest { @Test - public void testMapOntoVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMapOntoVector(Nd4jBackend backend) { INDArray arr = Nd4j.create(3, 2, 2); assertEquals(NDArrayMath.mapIndexOntoVector(2, arr), 4); } @Test - public void testNumVectors() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNumVectors(Nd4jBackend backend) { INDArray arr = Nd4j.create(3, 2, 2); assertEquals(4, NDArrayMath.vectorsPerSlice(arr)); INDArray matrix = Nd4j.create(2, 2); @@ -95,7 +105,9 @@ public class NDArrayMathTests extends BaseNd4jTest { } @Test - public void testOffsetForSlice() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOffsetForSlice(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); int[] dimensions = {0, 1}; INDArray permuted = arr.permute(2, 3, 0, 1); @@ -131,14 +143,18 @@ public class NDArrayMathTests extends BaseNd4jTest { } @Test - public void testOddDimensions() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOddDimensions(Nd4jBackend backend) { INDArray arr = Nd4j.create(3, 2, 2); val numMatrices = NDArrayMath.matricesPerSlice(arr); assertEquals(1, numMatrices); } @Test - public void testTotalVectors() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTotalVectors(Nd4jBackend backend) { INDArray arr2 = Nd4j.create(2, 2, 2, 2); assertEquals(8, NDArrayMath.numVectors(arr2)); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeBufferTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeBufferTests.java index 3e82c9844..539412bc0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeBufferTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeBufferTests.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.shape; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; @@ -33,12 +34,8 @@ import org.nd4j.common.util.ArrayUtil; import static org.junit.jupiter.api.Assertions.*; -@RunWith(Parameterized.class) -public class ShapeBufferTests extends BaseNd4jTest { - public ShapeBufferTests(Nd4jBackend backend) { - super(backend); - } +public class ShapeBufferTests extends BaseNd4jTestWithBackends { @Override public char ordering() { @@ -46,7 +43,9 @@ public class ShapeBufferTests extends BaseNd4jTest { } @Test - public void testRank() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRank(Nd4jBackend backend) { long[] shape = {2, 4}; long[] stride = {1, 2}; val shapeInfoBuffer = Shape.createShapeInformation(shape, stride, 1, 'c', DataType.DOUBLE, false); @@ -56,7 +55,9 @@ public class ShapeBufferTests extends BaseNd4jTest { @Test - public void testArrCreationShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArrCreationShape(Nd4jBackend backend) { val arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); for (int i = 0; i < 2; i++) assertEquals(2, arr.size(i)); @@ -67,7 +68,9 @@ public class ShapeBufferTests extends BaseNd4jTest { } @Test - public void testShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShape(Nd4jBackend backend) { long[] shape = {2, 4}; long[] stride = {1, 2}; val shapeInfoBuffer = Shape.createShapeInformation(shape, stride, 1, 'c', DataType.DOUBLE, false); @@ -84,7 +87,9 @@ public class ShapeBufferTests extends BaseNd4jTest { } @Test - public void testBuff() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBuff(Nd4jBackend backend) { long[] shape = {1, 2}; long[] stride = {1, 2}; val buff = Shape.createShapeInformation(shape, stride, 1, 'c', DataType.DOUBLE, false).asNioLong(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTests.java index df1cf9ae3..d8f9daeef 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTests.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.shape; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; @@ -42,15 +43,12 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.all; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class ShapeTests extends BaseNd4jTest { - public ShapeTests(Nd4jBackend backend) { - super(backend); - } - +public class ShapeTests extends BaseNd4jTestWithBackends { @Test - public void testRowColVectorVsScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRowColVectorVsScalar(Nd4jBackend backend) { INDArray arr = Nd4j.create(2); assertTrue(arr.isRowVector()); INDArray colVector = arr.reshape(2,1); @@ -61,10 +59,12 @@ public class ShapeTests extends BaseNd4jTest { INDArray arr3 = Nd4j.scalar(1.0); assertFalse(arr3.isColumnVector()); assertFalse(arr3.isRowVector()); - } + } @Test - public void testSixteenZeroOne() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSixteenZeroOne(Nd4jBackend backend) { INDArray baseArr = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); assertEquals(4, baseArr.tensorsAlongDimension(0, 1)); INDArray columnVectorFirst = Nd4j.create(new double[][] {{1, 3}, {2, 4}}); @@ -72,7 +72,7 @@ public class ShapeTests extends BaseNd4jTest { INDArray columnVectorThird = Nd4j.create(new double[][] {{5, 7}, {6, 8}}); INDArray columnVectorFourth = Nd4j.create(new double[][] {{13, 15}, {14, 16}}); INDArray[] assertions = - new INDArray[] {columnVectorFirst, columnVectorSecond, columnVectorThird, columnVectorFourth}; + new INDArray[] {columnVectorFirst, columnVectorSecond, columnVectorThird, columnVectorFourth}; for (int i = 0; i < baseArr.tensorsAlongDimension(0, 1); i++) { INDArray test = baseArr.tensorAlongDimension(i, 0, 1); @@ -82,7 +82,9 @@ public class ShapeTests extends BaseNd4jTest { } @Test - public void testVectorAlongDimension1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorAlongDimension1(Nd4jBackend backend) { INDArray arr = Nd4j.create(1, 5, 5); assertEquals(arr.vectorsAlongDimension(0), 5); assertEquals(arr.vectorsAlongDimension(1), 5); @@ -94,12 +96,14 @@ public class ShapeTests extends BaseNd4jTest { } @Test - public void testSixteenSecondDim() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSixteenSecondDim(Nd4jBackend backend) { INDArray baseArr = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); INDArray[] assertions = new INDArray[] {Nd4j.create(new double[] {1, 5}), Nd4j.create(new double[] {9, 13}), - Nd4j.create(new double[] {3, 7}), Nd4j.create(new double[] {11, 15}), - Nd4j.create(new double[] {2, 6}), Nd4j.create(new double[] {10, 14}), - Nd4j.create(new double[] {4, 8}), Nd4j.create(new double[] {12, 16}), + Nd4j.create(new double[] {3, 7}), Nd4j.create(new double[] {11, 15}), + Nd4j.create(new double[] {2, 6}), Nd4j.create(new double[] {10, 14}), + Nd4j.create(new double[] {4, 8}), Nd4j.create(new double[] {12, 16}), }; @@ -113,7 +117,9 @@ public class ShapeTests extends BaseNd4jTest { @Test - public void testVectorAlongDimension() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorAlongDimension(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 24, 24, DataType.FLOAT).reshape(4, 3, 2); INDArray assertion = Nd4j.create(new float[] {5, 17}, new long[] {2}); INDArray vectorDimensionTest = arr.vectorAlongDimension(1, 2); @@ -144,11 +150,13 @@ public class ShapeTests extends BaseNd4jTest { } @Test - public void testThreeTwoTwo() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThreeTwoTwo(Nd4jBackend backend) { INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2); INDArray[] assertions = new INDArray[] {Nd4j.create(new double[] {1, 4}), Nd4j.create(new double[] {7, 10}), - Nd4j.create(new double[] {2, 5}), Nd4j.create(new double[] {8, 11}), - Nd4j.create(new double[] {3, 6}), Nd4j.create(new double[] {9, 12}), + Nd4j.create(new double[] {2, 5}), Nd4j.create(new double[] {8, 11}), + Nd4j.create(new double[] {3, 6}), Nd4j.create(new double[] {9, 12}), }; @@ -161,18 +169,22 @@ public class ShapeTests extends BaseNd4jTest { } @Test - public void testNoCopy() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNoCopy(Nd4jBackend backend) { INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12, DataType.DOUBLE); INDArray arr = Shape.newShapeNoCopy(threeTwoTwo, new long[] {3, 2, 2}, true); assertArrayEquals(arr.shape(), new long[] {3, 2, 2}); } @Test - public void testThreeTwoTwoTwo() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThreeTwoTwoTwo(Nd4jBackend backend) { INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2); INDArray[] assertions = new INDArray[] {Nd4j.create(new double[] {1, 7}), Nd4j.create(new double[] {4, 10}), - Nd4j.create(new double[] {2, 8}), Nd4j.create(new double[] {5, 11}), - Nd4j.create(new double[] {3, 9}), Nd4j.create(new double[] {6, 12}), + Nd4j.create(new double[] {2, 8}), Nd4j.create(new double[] {5, 11}), + Nd4j.create(new double[] {3, 9}), Nd4j.create(new double[] {6, 12}), }; @@ -185,7 +197,9 @@ public class ShapeTests extends BaseNd4jTest { } @Test - public void testNewAxis() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNewAxis(Nd4jBackend backend) { INDArray tensor = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2); INDArray assertion = Nd4j.create(new double[][] {{1, 7}, {4, 10}}).reshape(1, 2, 2); INDArray tensorGet = tensor.get(NDArrayIndex.point(0), NDArrayIndex.newAxis(), all(), all()); @@ -195,12 +209,14 @@ public class ShapeTests extends BaseNd4jTest { @Test - public void testSixteenFirstDim() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSixteenFirstDim(Nd4jBackend backend) { INDArray baseArr = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); INDArray[] assertions = new INDArray[] {Nd4j.create(new double[] {1, 3}), Nd4j.create(new double[] {9, 11}), - Nd4j.create(new double[] {5, 7}), Nd4j.create(new double[] {13, 15}), - Nd4j.create(new double[] {2, 4}), Nd4j.create(new double[] {10, 12}), - Nd4j.create(new double[] {6, 8}), Nd4j.create(new double[] {14, 16}), + Nd4j.create(new double[] {5, 7}), Nd4j.create(new double[] {13, 15}), + Nd4j.create(new double[] {2, 4}), Nd4j.create(new double[] {10, 12}), + Nd4j.create(new double[] {6, 8}), Nd4j.create(new double[] {14, 16}), }; @@ -214,27 +230,31 @@ public class ShapeTests extends BaseNd4jTest { @Test - public void testDimShuffle() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDimShuffle(Nd4jBackend backend) { INDArray scalarTest = Nd4j.scalar(0.0).reshape(1, -1); INDArray broadcast = scalarTest.dimShuffle(new Object[] {'x'}, new long[] {0, 1}, new boolean[] {true, true}); assertTrue(broadcast.rank() == 3); INDArray rowVector = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1); assertEquals(rowVector, - rowVector.dimShuffle(new Object[] {0, 1}, new int[] {0, 1}, new boolean[] {false, false})); + rowVector.dimShuffle(new Object[] {0, 1}, new int[] {0, 1}, new boolean[] {false, false})); //add extra dimension to row vector in middle INDArray rearrangedRowVector = - rowVector.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {true, true}); + rowVector.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {true, true}); assertArrayEquals(new long[] {1, 1, 4}, rearrangedRowVector.shape()); INDArray dimshuffed = rowVector.dimShuffle(new Object[] {'x', 0, 'x', 'x'}, new long[] {0, 1}, - new boolean[] {true, true}); + new boolean[] {true, true}); assertArrayEquals(new long[] {1, 1, 1, 1, 4}, dimshuffed.shape()); } @Test - public void testEight() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEight(Nd4jBackend backend) { INDArray baseArr = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(2, 2, 2); assertEquals(2, baseArr.tensorsAlongDimension(0, 1)); INDArray columnVectorFirst = Nd4j.create(new double[][] {{1, 3}, {2, 4}}); @@ -244,6 +264,8 @@ public class ShapeTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBroadcastShapes(){ //Test cases: in1Shape, in2Shape, shapeOf(op(in1,in2)) List> testCases = new ArrayList<>(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java index 45dd3b447..9af908a1e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.shape; import lombok.val; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.iter.NdIndexIterator; @@ -38,19 +39,13 @@ import static org.junit.jupiter.api.Assertions.*; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class ShapeTestsC extends BaseNd4jTest { +public class ShapeTestsC extends BaseNd4jTestWithBackends { - public ShapeTestsC(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } - - DataType initialType; + DataType initialType = Nd4j.dataType(); @AfterEach - public void after() { + public void after(Nd4jBackend backend) { Nd4j.setDataType(this.initialType); } @@ -58,7 +53,9 @@ public class ShapeTestsC extends BaseNd4jTest { @Test - public void testSixteenZeroOne() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSixteenZeroOne(Nd4jBackend backend) { INDArray baseArr = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); assertEquals(4, baseArr.tensorsAlongDimension(0, 1)); INDArray columnVectorFirst = Nd4j.create(new double[][] {{1, 5}, {9, 13}}); @@ -66,7 +63,7 @@ public class ShapeTestsC extends BaseNd4jTest { INDArray columnVectorThird = Nd4j.create(new double[][] {{3, 7}, {11, 15}}); INDArray columnVectorFourth = Nd4j.create(new double[][] {{4, 8}, {12, 16}}); INDArray[] assertions = - new INDArray[] {columnVectorFirst, columnVectorSecond, columnVectorThird, columnVectorFourth}; + new INDArray[] {columnVectorFirst, columnVectorSecond, columnVectorThird, columnVectorFourth}; for (int i = 0; i < baseArr.tensorsAlongDimension(0, 1); i++) { INDArray test = baseArr.tensorAlongDimension(i, 0, 1); assertEquals( assertions[i], test,"Wrong at index " + i); @@ -75,12 +72,14 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testSixteenSecondDim() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSixteenSecondDim(Nd4jBackend backend) { INDArray baseArr = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); INDArray[] assertions = new INDArray[] {Nd4j.create(new double[] {1, 3}), Nd4j.create(new double[] {2, 4}), - Nd4j.create(new double[] {5, 7}), Nd4j.create(new double[] {6, 8}), - Nd4j.create(new double[] {9, 11}), Nd4j.create(new double[] {10, 12}), - Nd4j.create(new double[] {13, 15}), Nd4j.create(new double[] {14, 16}), + Nd4j.create(new double[] {5, 7}), Nd4j.create(new double[] {6, 8}), + Nd4j.create(new double[] {9, 11}), Nd4j.create(new double[] {10, 12}), + Nd4j.create(new double[] {13, 15}), Nd4j.create(new double[] {14, 16}), }; @@ -93,11 +92,13 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testThreeTwoTwo() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThreeTwoTwo(Nd4jBackend backend) { INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2); INDArray[] assertions = new INDArray[] {Nd4j.create(new double[] {1, 3}), Nd4j.create(new double[] {2, 4}), - Nd4j.create(new double[] {5, 7}), Nd4j.create(new double[] {6, 8}), - Nd4j.create(new double[] {9, 11}), Nd4j.create(new double[] {10, 12}), + Nd4j.create(new double[] {5, 7}), Nd4j.create(new double[] {6, 8}), + Nd4j.create(new double[] {9, 11}), Nd4j.create(new double[] {10, 12}), }; @@ -110,11 +111,13 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testThreeTwoTwoTwo() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThreeTwoTwoTwo(Nd4jBackend backend) { INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2); INDArray[] assertions = new INDArray[] {Nd4j.create(new double[] {1, 2}), Nd4j.create(new double[] {3, 4}), - Nd4j.create(new double[] {5, 6}), Nd4j.create(new double[] {7, 8}), - Nd4j.create(new double[] {9, 10}), Nd4j.create(new double[] {11, 12}), + Nd4j.create(new double[] {5, 6}), Nd4j.create(new double[] {7, 8}), + Nd4j.create(new double[] {9, 10}), Nd4j.create(new double[] {11, 12}), }; @@ -126,7 +129,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testPutRow() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutRow(Nd4jBackend backend) { INDArray matrix = Nd4j.create(new double[][] {{1, 2}, {3, 4}}); for (int i = 0; i < matrix.rows(); i++) { INDArray row = matrix.getRow(i); @@ -139,12 +144,14 @@ public class ShapeTestsC extends BaseNd4jTest { @Test - public void testSixteenFirstDim() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSixteenFirstDim(Nd4jBackend backend) { INDArray baseArr = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); INDArray[] assertions = new INDArray[] {Nd4j.create(new double[] {1, 5}), Nd4j.create(new double[] {2, 6}), - Nd4j.create(new double[] {3, 7}), Nd4j.create(new double[] {4, 8}), - Nd4j.create(new double[] {9, 13}), Nd4j.create(new double[] {10, 14}), - Nd4j.create(new double[] {11, 15}), Nd4j.create(new double[] {12, 16}), + Nd4j.create(new double[] {3, 7}), Nd4j.create(new double[] {4, 8}), + Nd4j.create(new double[] {9, 13}), Nd4j.create(new double[] {10, 14}), + Nd4j.create(new double[] {11, 15}), Nd4j.create(new double[] {12, 16}), }; @@ -157,7 +164,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testReshapePermute() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReshapePermute(Nd4jBackend backend) { INDArray arrNoPermute = Nd4j.ones(DataType.DOUBLE,5, 3, 4); INDArray reshaped2dNoPermute = arrNoPermute.reshape(5 * 3, 4); //OK assertArrayEquals(reshaped2dNoPermute.shape(), new long[] {5 * 3, 4}); @@ -171,7 +180,9 @@ public class ShapeTestsC extends BaseNd4jTest { @Test - public void testEight() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEight(Nd4jBackend backend) { INDArray baseArr = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(2, 2, 2); assertEquals(2, baseArr.tensorsAlongDimension(0, 1)); INDArray columnVectorFirst = Nd4j.create(new double[][] {{1, 3}, {5, 7}}); @@ -185,7 +196,9 @@ public class ShapeTestsC extends BaseNd4jTest { @Test - public void testOtherReshape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOtherReshape(Nd4jBackend backend) { INDArray nd = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, new long[] {2, 3}); INDArray slice = nd.slice(1, 0); @@ -198,7 +211,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testVectorAlongDimension() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorAlongDimension(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2); INDArray assertion = Nd4j.create(new double[] {3, 4}); INDArray vectorDimensionTest = arr.vectorAlongDimension(1, 2); @@ -249,9 +264,9 @@ public class ShapeTestsC extends BaseNd4jTest { INDArray fourdTest = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); double[][] assertionsArr = - new double[][] {{1, 3}, {2, 4}, {5, 7}, {6, 8}, {9, 11}, {10, 12}, {13, 15}, {14, 16}, + new double[][] {{1, 3}, {2, 4}, {5, 7}, {6, 8}, {9, 11}, {10, 12}, {13, 15}, {14, 16}, - }; + }; assertEquals(assertionsArr.length, fourdTest.vectorsAlongDimension(2)); @@ -267,7 +282,9 @@ public class ShapeTestsC extends BaseNd4jTest { @Test - public void testColumnSum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testColumnSum(Nd4jBackend backend) { INDArray twoByThree = Nd4j.linspace(1, 600, 600, DataType.FLOAT).reshape(150, 4); INDArray columnVar = twoByThree.sum(0); INDArray assertion = Nd4j.create(new float[] {44850.0f, 45000.0f, 45150.0f, 45300.0f}); @@ -276,7 +293,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testRowMean() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRowMean(Nd4jBackend backend) { INDArray twoByThree = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray rowMean = twoByThree.mean(1); INDArray assertion = Nd4j.create(new double[] {1.5, 3.5}); @@ -286,7 +305,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testRowStd() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRowStd(Nd4jBackend backend) { INDArray twoByThree = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray rowStd = twoByThree.std(1); INDArray assertion = Nd4j.create(new double[] {0.7071067811865476f, 0.7071067811865476f}); @@ -296,7 +317,9 @@ public class ShapeTestsC extends BaseNd4jTest { @Test - public void testColumnSumDouble() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testColumnSumDouble(Nd4jBackend backend) { DataType initialType = Nd4j.dataType(); DataTypeUtil.setDTypeForContext(DataType.DOUBLE); INDArray twoByThree = Nd4j.linspace(1, 600, 600, DataType.DOUBLE).reshape(150, 4); @@ -308,7 +331,9 @@ public class ShapeTestsC extends BaseNd4jTest { @Test - public void testColumnVariance() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testColumnVariance(Nd4jBackend backend) { INDArray twoByThree = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray columnVar = twoByThree.var(true, 0); INDArray assertion = Nd4j.create(new double[] {2, 2}); @@ -318,7 +343,9 @@ public class ShapeTestsC extends BaseNd4jTest { @Test - public void testCumSum() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCumSum(Nd4jBackend backend) { INDArray n = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {1, 4}); INDArray cumSumAnswer = Nd4j.create(new double[] {1, 3, 6, 10}, new long[] {1, 4}); INDArray cumSumTest = n.cumsum(0); @@ -327,7 +354,7 @@ public class ShapeTestsC extends BaseNd4jTest { INDArray n2 = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2); INDArray axis0assertion = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, - 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0}, n2.shape()); + 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0}, n2.shape()); INDArray axis0Test = n2.cumsum(0); assertEquals(axis0assertion, axis0Test,getFailureMessage()); @@ -335,7 +362,9 @@ public class ShapeTestsC extends BaseNd4jTest { @Test - public void testSumRow() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSumRow(Nd4jBackend backend) { INDArray rowVector10 = Nd4j.ones(DataType.DOUBLE,1,10); INDArray sum1 = rowVector10.sum(1); assertArrayEquals(new long[] {1}, sum1.shape()); @@ -343,7 +372,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testSumColumn() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSumColumn(Nd4jBackend backend) { INDArray colVector10 = Nd4j.ones(10, 1); INDArray sum0 = colVector10.sum(0); assertArrayEquals( new long[] {1}, sum0.shape()); @@ -351,7 +382,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testSum2d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum2d(Nd4jBackend backend) { INDArray arr = Nd4j.ones(10, 10); INDArray sum0 = arr.sum(0); assertArrayEquals(new long[] {10}, sum0.shape()); @@ -361,7 +394,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testSum2dv2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSum2dv2(Nd4jBackend backend) { INDArray arr = Nd4j.ones(10, 10); INDArray sumBoth = arr.sum(0, 1); assertArrayEquals(new long[0], sumBoth.shape()); @@ -369,7 +404,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testPermuteReshape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPermuteReshape(Nd4jBackend backend) { INDArray arrTest = Nd4j.arange(60).reshape('c', 3, 4, 5); INDArray permute = arrTest.permute(2, 1, 0); assertArrayEquals(new long[] {5, 4, 3}, permute.shape()); @@ -381,7 +418,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testRavel() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRavel(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray asseriton = Nd4j.linspace(1, 4, 4); INDArray raveled = linspace.ravel(); @@ -395,11 +434,13 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testPutScalar() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPutScalar(Nd4jBackend backend) { //Check that the various putScalar methods have the same result... val shapes = new int[][] {{3, 4}, {1, 4}, {3, 1}, {3, 4, 5}, {1, 4, 5}, {3, 1, 5}, {3, 4, 1}, {1, 1, 5}, - {3, 4, 5, 6}, {1, 4, 5, 6}, {3, 1, 5, 6}, {3, 4, 1, 6}, {3, 4, 5, 1}, {1, 1, 5, 6}, - {3, 1, 1, 6}, {3, 1, 1, 1}}; + {3, 4, 5, 6}, {1, 4, 5, 6}, {3, 1, 5, 6}, {3, 4, 1, 6}, {3, 4, 5, 1}, {1, 1, 5, 6}, + {3, 1, 1, 6}, {3, 1, 1, 1}}; for (int[] shape : shapes) { int rank = shape.length; @@ -441,7 +482,9 @@ public class ShapeTestsC extends BaseNd4jTest { @Test - public void testReshapeToTrueScalar_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReshapeToTrueScalar_1(Nd4jBackend backend) { val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1}); val exp = Nd4j.scalar(1.0f); @@ -454,7 +497,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testReshapeToTrueScalar_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReshapeToTrueScalar_2(Nd4jBackend backend) { val orig = Nd4j.create(new float[]{1.0f}, new int[]{1}); val exp = Nd4j.scalar(1.0f); @@ -467,7 +512,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testReshapeToTrueScalar_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReshapeToTrueScalar_3(Nd4jBackend backend) { val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1}); val exp = Nd4j.createFromArray(new float[]{1.0f}); @@ -480,7 +527,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testReshapeToTrueScalar_4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReshapeToTrueScalar_4(Nd4jBackend backend) { val orig = Nd4j.create(new float[]{1.0f}, new int[]{1, 1}); val exp = Nd4j.scalar(1.0f); @@ -493,7 +542,9 @@ public class ShapeTestsC extends BaseNd4jTest { } @Test - public void testViewAfterReshape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testViewAfterReshape(Nd4jBackend backend) { val x = Nd4j.rand(3,4); val x2 = x.ravel(); val x3 = x.reshape(6,2); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java index 7a7386f9d..43b3d83e5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/StaticShapeTests.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.shape; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; @@ -44,16 +45,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class StaticShapeTests extends BaseNd4jTest { - - public StaticShapeTests(Nd4jBackend backend) { - super(backend); - } +public class StaticShapeTests extends BaseNd4jTestWithBackends { @Test - public void testShapeInd2Sub() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShapeInd2Sub(Nd4jBackend backend) { long normalTotal = 0; long n = 1000; for (int i = 0; i < n; i++) { @@ -72,7 +70,9 @@ public class StaticShapeTests extends BaseNd4jTest { @Test - public void testBufferToIntShapeStrideMethods() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBufferToIntShapeStrideMethods(Nd4jBackend backend) { //Specifically: Shape.shape(IntBuffer), Shape.shape(DataBuffer) //.isRowVectorShape(DataBuffer), .isRowVectorShape(IntBuffer) //Shape.size(DataBuffer,int), Shape.size(IntBuffer,int) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java index c85264965..0a7d9a731 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.shape; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -42,15 +43,14 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.point; @Slf4j -@RunWith(Parameterized.class) -public class TADTests extends BaseNd4jTest { - public TADTests(Nd4jBackend backend) { - super(backend); - } +public class TADTests extends BaseNd4jTestWithBackends { + @Test - public void testStall() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStall(Nd4jBackend backend) { //[4, 3, 3, 4, 5, 60, 20, 5, 1, 0, 1, 99], dimensions: [1, 2, 3] INDArray arr = Nd4j.create(3, 3, 4, 5); arr.tensorAlongDimension(0, 1, 2, 3); @@ -64,13 +64,15 @@ public class TADTests extends BaseNd4jTest { * @throws Exception */ @Test - public void testEquality1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testEquality1(Nd4jBackend backend) { char[] order = new char[] {'c', 'f'}; int[] dim_e = new int[] {0, 2}; int[] dim_x = new int[] {1, 3}; List dim_3 = Arrays.asList(new int[] {0, 2, 3}, new int[] {0, 1, 2}, new int[] {1, 2, 3}, - new int[] {0, 1, 3}); + new int[] {0, 1, 3}); for (char o : order) { @@ -119,15 +121,17 @@ public class TADTests extends BaseNd4jTest { } @Test - public void testMysteriousCrash() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMysteriousCrash(Nd4jBackend backend) { INDArray arrayF = Nd4j.create(new int[] {1, 1, 4, 4}, 'f'); INDArray arrayC = Nd4j.create(new int[] {1, 1, 4, 4}, 'c'); INDArray javaCTad = arrayC.tensorAlongDimension(0, 2, 3); INDArray javaFTad = arrayF.tensorAlongDimension(0, 2, 3); Pair tadBuffersF = - Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(arrayF, 2, 3); + Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(arrayF, 2, 3); Pair tadBuffersC = - Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(arrayC, 2, 3); + Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(arrayC, 2, 3); // log.info("Got TADShapeF: {}", Arrays.toString(tadBuffersF.getFirst().asInt()) + " with java " // + javaFTad.shapeInfoDataBuffer()); @@ -136,6 +140,8 @@ public class TADTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTADEWSStride(){ INDArray orig = Nd4j.linspace(1, 600, 600).reshape('f', 10, 1, 60); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java index 846166367..155a900e7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTests.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.shape.concat; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; @@ -45,16 +46,15 @@ import static org.junit.jupiter.api.Assertions.assertTrue; * @author Adam Gibson */ @Slf4j -@RunWith(Parameterized.class) -public class ConcatTests extends BaseNd4jTest { - public ConcatTests(Nd4jBackend backend) { - super(backend); - } +public class ConcatTests extends BaseNd4jTestWithBackends { + @Test - public void testConcat() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcat(Nd4jBackend backend) { INDArray A = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(2, 2, 2); INDArray B = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2); INDArray concat = Nd4j.concat(0, A, B); @@ -63,7 +63,9 @@ public class ConcatTests extends BaseNd4jTest { } @Test - public void testConcatHorizontally() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatHorizontally(Nd4jBackend backend) { INDArray rowVector = Nd4j.ones(1, 5); INDArray other = Nd4j.ones(1, 5); INDArray concat = Nd4j.hstack(other, rowVector); @@ -74,7 +76,9 @@ public class ConcatTests extends BaseNd4jTest { @Test - public void testVStackColumn() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVStackColumn(Nd4jBackend backend) { INDArray linspaced = Nd4j.linspace(1, 3, 3, DataType.DOUBLE).reshape(3, 1); INDArray stacked = linspaced.dup(); INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 1, 2, 3}, new int[] {6, 1}); @@ -84,7 +88,9 @@ public class ConcatTests extends BaseNd4jTest { @Test - public void testConcatScalars() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatScalars(Nd4jBackend backend) { INDArray first = Nd4j.arange(0, 1).reshape(1, 1); INDArray second = Nd4j.arange(0, 1).reshape(1, 1); INDArray firstRet = Nd4j.concat(0, first, second); @@ -95,7 +101,9 @@ public class ConcatTests extends BaseNd4jTest { @Test - public void testConcatMatrices() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatMatrices(Nd4jBackend backend) { INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray b = a.dup(); @@ -110,7 +118,9 @@ public class ConcatTests extends BaseNd4jTest { } @Test - public void testConcatRowVectors() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatRowVectors(Nd4jBackend backend) { INDArray rowVector = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, new int[] {1, 6}); INDArray matrix = Nd4j.create(new double[] {7, 8, 9, 10, 11, 12}, new int[] {1, 6}); @@ -125,7 +135,9 @@ public class ConcatTests extends BaseNd4jTest { @Test - public void testConcat3d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcat3d(Nd4jBackend backend) { INDArray first = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape('c', 2, 3, 4); INDArray second = Nd4j.linspace(24, 36, 12, DataType.DOUBLE).reshape('c', 1, 3, 4); INDArray third = Nd4j.linspace(36, 48, 12, DataType.DOUBLE).reshape('c', 1, 3, 4); @@ -172,7 +184,9 @@ public class ConcatTests extends BaseNd4jTest { @Test @Disabled - public void testConcat3dv2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcat3dv2(Nd4jBackend backend) { INDArray first = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape('c', 2, 3, 4); INDArray second = Nd4j.linspace(24, 35, 12, DataType.DOUBLE).reshape('c', 1, 3, 4); @@ -254,6 +268,8 @@ public class ConcatTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void concatf(){ char orderBefore = Nd4j.order(); try { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java index 391d1fec7..6af498231 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java @@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; @@ -48,16 +49,15 @@ import static org.junit.jupiter.api.Assertions.*; * @author Adam Gibson */ @Slf4j -@RunWith(Parameterized.class) -public class ConcatTestsC extends BaseNd4jTest { - public ConcatTestsC(Nd4jBackend backend) { - super(backend); - } +public class ConcatTestsC extends BaseNd4jTestWithBackends { + @Test - public void testConcatVertically() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatVertically(Nd4jBackend backend) { INDArray rowVector = Nd4j.ones(1, 5); INDArray other = Nd4j.ones(1, 5); INDArray concat = Nd4j.vstack(other, rowVector); @@ -79,7 +79,9 @@ public class ConcatTestsC extends BaseNd4jTest { @Test - public void testConcatScalars() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatScalars(Nd4jBackend backend) { INDArray first = Nd4j.arange(0, 1).reshape(1, 1); INDArray second = Nd4j.arange(0, 1).reshape(1, 1); INDArray firstRet = Nd4j.concat(0, first, second); @@ -89,7 +91,9 @@ public class ConcatTestsC extends BaseNd4jTest { } @Test - public void testConcatScalars1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatScalars1(Nd4jBackend backend) { INDArray first = Nd4j.scalar(1); INDArray second = Nd4j.scalar(2); INDArray third = Nd4j.scalar(3); @@ -102,7 +106,9 @@ public class ConcatTestsC extends BaseNd4jTest { } @Test - public void testConcatVectors1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatVectors1(Nd4jBackend backend) { INDArray first = Nd4j.ones(1, 10); INDArray second = Nd4j.ones(1, 10); INDArray third = Nd4j.ones(1, 10); @@ -120,7 +126,9 @@ public class ConcatTestsC extends BaseNd4jTest { } @Test - public void testConcatMatrices() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatMatrices(Nd4jBackend backend) { INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray b = a.dup(); @@ -139,7 +147,9 @@ public class ConcatTestsC extends BaseNd4jTest { } @Test - public void testAssign() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAssign(Nd4jBackend backend) { INDArray vector = Nd4j.linspace(1, 5, 5, Nd4j.dataType()); vector.assign(1); assertEquals(Nd4j.ones(5), vector); @@ -156,7 +166,9 @@ public class ConcatTestsC extends BaseNd4jTest { } @Test - public void testConcatRowVectors() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatRowVectors(Nd4jBackend backend) { INDArray rowVector = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, new int[] {1, 6}); INDArray matrix = Nd4j.create(new double[] {7, 8, 9, 10, 11, 12}, new int[] {1, 6}); @@ -171,7 +183,9 @@ public class ConcatTestsC extends BaseNd4jTest { @Test - public void testConcat3d() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcat3d(Nd4jBackend backend) { INDArray first = Nd4j.linspace(1, 24, 24, Nd4j.dataType()).reshape('c', 2, 3, 4); INDArray second = Nd4j.linspace(24, 36, 12, Nd4j.dataType()).reshape('c', 1, 3, 4); INDArray third = Nd4j.linspace(36, 48, 12, Nd4j.dataType()).reshape('c', 1, 3, 4); @@ -218,7 +232,9 @@ public class ConcatTestsC extends BaseNd4jTest { } @Test() - public void testConcatVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcatVector(Nd4jBackend backend) { assertThrows(ND4JIllegalStateException.class,() -> { Nd4j.concat(0, Nd4j.ones(1,1000000), Nd4j.create(1, 1)); @@ -227,7 +243,9 @@ public class ConcatTestsC extends BaseNd4jTest { @Test @Disabled - public void testConcat3dv2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testConcat3dv2(Nd4jBackend backend) { INDArray first = Nd4j.linspace(1, 24, 24).reshape('c', 2, 3, 4); INDArray second = Nd4j.linspace(24, 35, 12).reshape('c', 1, 3, 4); @@ -311,7 +329,9 @@ public class ConcatTestsC extends BaseNd4jTest { @Test - public void testLargeConcat() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLargeConcat(Nd4jBackend backend) { val list = new ArrayList(); for (int e = 0; e < 20000; e++) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java index 47867eae1..b387c870d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTests.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.shape.concat.padding; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -35,17 +36,16 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class PaddingTests extends BaseNd4jTest { - public PaddingTests(Nd4jBackend backend) { - super(backend); - } +public class PaddingTests extends BaseNd4jTestWithBackends { + @Test - public void testAppend() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAppend(Nd4jBackend backend) { INDArray appendTo = Nd4j.ones(DataType.DOUBLE,3, 3); INDArray ret = Nd4j.append(appendTo, 3, 1, -1); assertArrayEquals(new long[] {3, 6}, ret.shape()); @@ -60,7 +60,9 @@ public class PaddingTests extends BaseNd4jTest { } @Test - public void testPrepend() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPrepend(Nd4jBackend backend) { INDArray appendTo = Nd4j.ones(DataType.DOUBLE, 3, 3); INDArray ret = Nd4j.append(appendTo, 3, 1, -1); assertArrayEquals(new long[] {3, 6}, ret.shape()); @@ -76,17 +78,19 @@ public class PaddingTests extends BaseNd4jTest { @Test - public void testPad() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPad(Nd4jBackend backend) { INDArray start = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3); INDArray ret = Nd4j.pad(start, 5, 5); double[][] data = new double[][] {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, - {0, 0, 0, 0, 0, 1, 4, 7, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 2, 5, 8, 0, 0, 0, 0, 0.}, - {0, 0, 0, 0, 0, 3, 6, 9, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}}; + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, + {0, 0, 0, 0, 0, 1, 4, 7, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 2, 5, 8, 0, 0, 0, 0, 0.}, + {0, 0, 0, 0, 0, 3, 6, 9, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.}}; INDArray assertion = Nd4j.create(data); assertEquals(assertion, ret); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java index 055185e58..d9ec9d7a5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/padding/PaddingTestsC.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.shape.concat.padding; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.convolution.Convolution; @@ -37,11 +38,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class PaddingTestsC extends BaseNd4jTest { - public PaddingTestsC(Nd4jBackend backend) { - super(backend); - } + +public class PaddingTestsC extends BaseNd4jTestWithBackends { @Override public char ordering() { @@ -49,7 +47,9 @@ public class PaddingTestsC extends BaseNd4jTest { } @Test - public void testPrepend() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPrepend(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray assertion = Nd4j.create(new double[][] {{1, 1, 1, 1, 2}, {1, 1, 1, 3, 4}}); @@ -61,34 +61,38 @@ public class PaddingTestsC extends BaseNd4jTest { @Test - public void testPaddingOneThrougFour() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPaddingOneThrougFour(Nd4jBackend backend) { int ph = 0; int pw = 0; int sy = 2; int sx = 2; INDArray ret = Nd4j.create(new double[] {1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4}, new int[] {1, 1, 8, 8}); + 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4}, new int[] {1, 1, 8, 8}); INDArray padded = Nd4j.pad(ret, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}}); INDArray assertion = Nd4j.create(new double[] {1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 4, 4, 4, 4, 4, 4, 4, 4, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - new int[] {1, 1, 9, 9}); + new int[] {1, 1, 9, 9}); assertArrayEquals(assertion.shape(), padded.shape()); assertEquals(assertion, padded); } @Test - public void testAppend2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAppend2(Nd4jBackend backend) { INDArray ret = Nd4j.create(new double[] {1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4}, new int[] {1, 1, 8, 8}); + 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4}, new int[] {1, 1, 8, 8}); INDArray appendAssertion = Nd4j.create(new double[] {1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, - 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, - 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0}, new int[] {1, 1, 9, 8}); + 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, + 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0}, new int[] {1, 1, 9, 8}); INDArray appended = Nd4j.append(ret, 1, 0, 2); assertArrayEquals(appendAssertion.shape(), appended.shape()); @@ -96,7 +100,9 @@ public class PaddingTestsC extends BaseNd4jTest { } @Test - public void testPaddingTensor() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPaddingTensor(Nd4jBackend backend) { //,1,1,1,1,2,2,0 int kh = 1, kw = 1, sy = 1, sx = 1, ph = 2, pw = 2; INDArray linspaced = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); @@ -114,7 +120,9 @@ public class PaddingTestsC extends BaseNd4jTest { @Test - public void testAppend() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAppend(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray otherAppend = Nd4j.append(linspace, 3, 1.0, -1); INDArray assertion = Nd4j.create(new double[][] {{1, 2, 1, 1, 1}, {3, 4, 1, 1, 1}}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java index 522b0fe2c..af8209de3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTests.java @@ -25,10 +25,10 @@ import lombok.val; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.rules.ErrorCollector; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -43,16 +43,14 @@ import static org.junit.jupiter.api.Assertions.*; * @author Adam Gibson */ @Slf4j -@RunWith(Parameterized.class) -public class IndexingTests extends BaseNd4jTest { + +public class IndexingTests extends BaseNd4jTestWithBackends { - public IndexingTests(Nd4jBackend backend) { - super(backend); - } - - @Test - public void testGet() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGet(Nd4jBackend backend) { // System.out.println("Testing sub-array put and get with a 3D array ..."); INDArray arr = Nd4j.linspace(0, 124, 125).reshape(5, 5, 5); @@ -112,8 +110,10 @@ public class IndexingTests extends BaseNd4jTest { /* Simple test that checks indexing through different ways that fails */ - @Test - public void testSimplePoint() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSimplePoint(Nd4jBackend backend) { INDArray A = Nd4j.linspace(1, 3 * 3 * 3, 3 * 3 * 3).reshape(3, 3, 3); /* @@ -143,8 +143,10 @@ public class IndexingTests extends BaseNd4jTest { This is the same as the above test - just tests every possible window with a slice from the 0th dim They all fail - so it's possibly unrelated to the value of the index */ - @Test - public void testPointIndexing() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPointIndexing(Nd4jBackend backend) { int slices = 5; int rows = 5; int cols = 5; @@ -177,7 +179,7 @@ public class IndexingTests extends BaseNd4jTest { // The test .equals fails on a comparison of row vs column vector. //TODO: possibly figure out what's going on here at some point? // - Adam - public void testTensorGet() { + public void testTensorGet(Nd4jBackend backend) { INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12).reshape(3, 2, 2); /* * [[[ 1., 7.], @@ -198,8 +200,10 @@ public class IndexingTests extends BaseNd4jTest { assertEquals(secondAssertion, secondTest); } - @Test - public void concatGetBug() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void concatGetBug(Nd4jBackend backend) { int width = 5; int height = 4; int depth = 3; @@ -223,8 +227,10 @@ public class IndexingTests extends BaseNd4jTest { assertEquals(second, get); //Fails } - @Test - public void testShape() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShape(Nd4jBackend backend) { INDArray ndarray = Nd4j.create(new float[][] {{1f, 2f}, {3f, 4f}}); INDArray subarray = ndarray.get(NDArrayIndex.point(0), NDArrayIndex.all()); assertTrue(subarray.isRowVector()); @@ -232,8 +238,10 @@ public class IndexingTests extends BaseNd4jTest { assertArrayEquals(new long[]{2}, shape); } - @Test - public void testGetRows() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRows(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3); INDArray testAssertion = Nd4j.create(new double[][] {{5, 8}, {6, 9}}); @@ -242,8 +250,10 @@ public class IndexingTests extends BaseNd4jTest { } - @Test - public void testFirstColumn() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFirstColumn(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[][] {{5, 6}, {7, 8}}); INDArray assertion = Nd4j.create(new double[] {5, 7}); @@ -252,8 +262,10 @@ public class IndexingTests extends BaseNd4jTest { } - @Test - public void testLinearIndex() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLinearIndex(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 4, 4).reshape(2, 2); for (int i = 0; i < linspace.length(); i++) { assertEquals(i + 1, linspace.getDouble(i), 1e-1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java index d6507d6ce..12b3c3b88 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java @@ -24,10 +24,10 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.rules.ErrorCollector; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd; @@ -43,16 +43,15 @@ import static org.junit.jupiter.api.Assertions.*; * @author Adam Gibson */ @Slf4j -@RunWith(Parameterized.class) -public class IndexingTestsC extends BaseNd4jTest { + +public class IndexingTestsC extends BaseNd4jTestWithBackends { - public IndexingTestsC(Nd4jBackend backend) { - super(backend); - } - @Test - public void testExecSubArray() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testExecSubArray(Nd4jBackend backend) { INDArray nd = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, new int[] {2, 3}); INDArray sub = nd.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)); @@ -62,16 +61,20 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test - public void testLinearViewElementWiseMatching() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLinearViewElementWiseMatching(Nd4jBackend backend) { INDArray linspace = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray dup = linspace.dup(); linspace.addi(dup); } - @Test - public void testGetRows() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRows(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3); INDArray testAssertion = Nd4j.create(new double[][] {{4, 5}, {7, 8}}); @@ -80,8 +83,10 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test - public void testFirstColumn() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFirstColumn(Nd4jBackend backend) { INDArray arr = Nd4j.create(new double[][] {{5, 7}, {6, 8}}); INDArray assertion = Nd4j.create(new double[] {5, 6}); @@ -89,8 +94,10 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(assertion, test); } - @Test - public void testMultiRow() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultiRow(Nd4jBackend backend) { INDArray matrix = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3); INDArray assertion = Nd4j.create(new double[][] {{4, 7}}); @@ -98,8 +105,10 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(assertion, test); } - @Test - public void testPointIndexes() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPointIndexes(Nd4jBackend backend) { INDArray arr = Nd4j.create(DataType.DOUBLE, 4, 3, 2); INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(1), NDArrayIndex.all()); assertArrayEquals(new long[] {4, 2}, get.shape()); @@ -115,8 +124,10 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(assertion, linspacedGet); } - @Test - public void testGetWithVariedStride() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetWithVariedStride(Nd4jBackend backend) { int ph = 0; int pw = 0; int sy = 2; @@ -165,8 +176,10 @@ public class IndexingTestsC extends BaseNd4jTest { } - @Test - public void testRowVectorInterval() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testRowVectorInterval(Nd4jBackend backend) { int len = 30; INDArray row = Nd4j.zeros(1, len); for (int i = 0; i < len; i++) { @@ -194,8 +207,10 @@ public class IndexingTestsC extends BaseNd4jTest { assertTrue(last10b.getDouble(i) == 20 + i); } - @Test - public void test1dSubarray_1() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test1dSubarray_1(Nd4jBackend backend) { val data = Nd4j.linspace(DataType.FLOAT,0, 10, 1); val exp = Nd4j.createFromArray(new float[]{3.f, 4.f}); val dataAtIndex = data.get(NDArrayIndex.interval(3, 5)); @@ -203,8 +218,10 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(exp, dataAtIndex); } - @Test - public void test1dSubarray_2() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test1dSubarray_2(Nd4jBackend backend) { val data = Nd4j.linspace(DataType.FLOAT,1, 10, 1); val exp = Nd4j.createFromArray(new float[]{4.f, 6.f}); val dataAtIndex = data.get(Nd4j.createFromArray(new int[]{3, 5})); @@ -212,8 +229,10 @@ public class IndexingTestsC extends BaseNd4jTest { assertEquals(exp, dataAtIndex); } - @Test - public void testGet() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGet(Nd4jBackend backend) { // System.out.println("Testing sub-array put and get with a 3D array ..."); INDArray arr = Nd4j.linspace(0, 124, 125).reshape(5, 5, 5); @@ -269,8 +288,10 @@ public class IndexingTestsC extends BaseNd4jTest { // System.out.println("... done"); } - @Test - public void testSimplePoint() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSimplePoint(Nd4jBackend backend) { INDArray A = Nd4j.linspace(1, 3 * 3 * 3, 3 * 3 * 3).reshape(3, 3, 3); /* @@ -295,8 +316,10 @@ public class IndexingTestsC extends BaseNd4jTest { This is the same as the above test - just tests every possible window with a slice from the 0th dim They all fail - so it's possibly unrelated to the value of the index */ - @Test - public void testPointIndexing() { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPointIndexing(Nd4jBackend backend) { int slices = 5; int rows = 5; int cols = 5; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnes.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnes.java index eb691afef..c7f63053e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnes.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnes.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.shape.ones; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -37,15 +38,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class LeadingAndTrailingOnes extends BaseNd4jTest { - public LeadingAndTrailingOnes(Nd4jBackend backend) { - super(backend); - } +public class LeadingAndTrailingOnes extends BaseNd4jTestWithBackends { + @Test - public void testSliceConstructor() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSliceConstructor(Nd4jBackend backend) { List testList = new ArrayList<>(); for (int i = 0; i < 5; i++) testList.add(Nd4j.scalar(DataType.DOUBLE, i + 1)); @@ -56,7 +56,9 @@ public class LeadingAndTrailingOnes extends BaseNd4jTest { } @Test - public void testLeadAndTrail() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLeadAndTrail(Nd4jBackend backend) { INDArray fourD = Nd4j.create(1, 2, 1, 1); assertEquals(2, fourD.length()); for (int i = 0; i < fourD.length(); i++) @@ -65,7 +67,9 @@ public class LeadingAndTrailingOnes extends BaseNd4jTest { } @Test - public void testCreateLeadingAndTrailingOnes() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCreateLeadingAndTrailingOnes(Nd4jBackend backend) { INDArray arr = Nd4j.create(1, 10, 1, 1); arr.assign(1); arr.toString(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnesC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnesC.java index cf9a1a9b3..424b181be 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnesC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ones/LeadingAndTrailingOnesC.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.shape.ones; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -34,22 +35,23 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class LeadingAndTrailingOnesC extends BaseNd4jTest { - public LeadingAndTrailingOnesC(Nd4jBackend backend) { - super(backend); - } +public class LeadingAndTrailingOnesC extends BaseNd4jTestWithBackends { + @Test - public void testCreateLeadingAndTrailingOnes() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCreateLeadingAndTrailingOnes(Nd4jBackend backend) { INDArray arr = Nd4j.create(1, 10, 1, 1); arr.assign(1); // System.out.println(arr); } @Test - public void testMatrix() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatrix(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray slice1 = arr.slice(1); // System.out.println(arr.slice(1)); @@ -59,13 +61,15 @@ public class LeadingAndTrailingOnesC extends BaseNd4jTest { // System.out.println(otherSlice); INDArray twoOnesInMiddle = Nd4j.linspace(1, 4, 4).reshape(2, 1, 1, 2); INDArray sub = twoOnesInMiddle.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.all()); + NDArrayIndex.all()); assertEquals(2, sub.offset()); } @Test - public void testMultipleOnesInMiddle() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMultipleOnesInMiddle(Nd4jBackend backend) { INDArray tensor = Nd4j.linspace(1, 144, 144).reshape(2, 2, 1, 1, 6, 6); INDArray tensorSlice1 = tensor.slice(1); INDArray tensorSlice1Slice1 = tensorSlice1.slice(1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java index 144fc146f..ce184a659 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/reshape/ReshapeTests.java @@ -22,9 +22,10 @@ package org.nd4j.linalg.shape.reshape; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -38,15 +39,14 @@ import static org.junit.Assume.assumeNotNull; * @author Adam Gibson */ @Slf4j -@RunWith(Parameterized.class) -public class ReshapeTests extends BaseNd4jTest { - public ReshapeTests(Nd4jBackend backend) { - super(backend); - } +public class ReshapeTests extends BaseNd4jTestWithBackends { + @Test - public void testThreeTwoTwoTwo() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testThreeTwoTwoTwo(Nd4jBackend backend) { INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2); INDArray sliceZero = Nd4j.create(new double[][] {{1, 7}, {4, 10}}); INDArray sliceOne = Nd4j.create(new double[][] {{2, 8}, {5, 11}}); @@ -67,7 +67,9 @@ public class ReshapeTests extends BaseNd4jTest { @Test - public void testColumnVectorReshape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testColumnVectorReshape(Nd4jBackend backend) { double delta = 1e-1; INDArray arr = Nd4j.create(1, 3); INDArray reshaped = arr.reshape('f', 3, 1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTests.java index 6f8d80828..a8faf7470 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTests.java @@ -22,27 +22,26 @@ package org.nd4j.linalg.slicing; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class SlicingTests extends BaseNd4jTest { - public SlicingTests(Nd4jBackend backend) { - super(backend); - } +public class SlicingTests extends BaseNd4jTestWithBackends { + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSlices() { INDArray arr = Nd4j.create(Nd4j.linspace(1, 24, 24, DataType.DOUBLE).data(), new int[] {4, 3, 2}); for (int i = 0; i < arr.slices(); i++) { @@ -56,6 +55,8 @@ public class SlicingTests extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSlice() { INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2); INDArray assertion = Nd4j.create(new double[][] {{1, 13}, {5, 17}, {9, 21}}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java index b627ea3b0..b273d5196 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/slicing/SlicingTestsC.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.slicing; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -37,16 +38,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class SlicingTestsC extends BaseNd4jTest { - - public SlicingTestsC(Nd4jBackend backend) { - super(backend); - } +public class SlicingTestsC extends BaseNd4jTestWithBackends { + @Test - public void testSliceRowVector() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSliceRowVector(Nd4jBackend backend) { INDArray arr = Nd4j.zeros(5); // System.out.println(arr.slice(1)); arr.slice(1); @@ -54,7 +53,9 @@ public class SlicingTestsC extends BaseNd4jTest { } @Test - public void testSliceAssertion() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSliceAssertion(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 30, 30).reshape(3, 5, 2); INDArray firstRow = arr.slice(0).slice(0); // for (int i = 0; i < firstRow.length(); i++) { @@ -64,7 +65,9 @@ public class SlicingTestsC extends BaseNd4jTest { } @Test - public void testSliceShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSliceShape(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(3, 5, 2); INDArray sliceZero = arr.slice(0); @@ -93,7 +96,9 @@ public class SlicingTestsC extends BaseNd4jTest { } @Test - public void testSwapReshape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSwapReshape(Nd4jBackend backend) { INDArray n2 = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.FLOAT).data(), new int[] {3, 5, 2}); INDArray swapped = n2.swapAxes(n2.shape().length - 1, 1); INDArray firstSlice2 = swapped.slice(0).slice(0); @@ -114,7 +119,9 @@ public class SlicingTestsC extends BaseNd4jTest { @Test - public void testGetRow() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGetRow(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); INDArray get = arr.getRow(1); INDArray get2 = arr.get(NDArrayIndex.point(1), NDArrayIndex.all()); @@ -132,7 +139,9 @@ public class SlicingTestsC extends BaseNd4jTest { } @Test - public void testVectorIndexing() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorIndexing(Nd4jBackend backend) { INDArray zeros = Nd4j.create(1, 400000); INDArray get = zeros.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 300000)); assertArrayEquals(new long[] {300000}, get.shape()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/CudaTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/CudaTests.java index 9347addcb..eedcd8fab 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/CudaTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/CudaTests.java @@ -25,9 +25,10 @@ import lombok.val; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.factory.Nd4j; @@ -36,15 +37,11 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j -@RunWith(Parameterized.class) -public class CudaTests extends BaseNd4jTest { - DataType initialType; +public class CudaTests extends BaseNd4jTestWithBackends { + + DataType initialType = Nd4j.dataType(); - public CudaTests(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } @BeforeEach public void setUp() { @@ -57,7 +54,9 @@ public class CudaTests extends BaseNd4jTest { } @Test - public void testMGrid_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMGrid_1(Nd4jBackend backend) { if (!(Nd4j.getExecutioner() instanceof GridExecutioner)) return; @@ -78,7 +77,9 @@ public class CudaTests extends BaseNd4jTest { @Test - public void testMGrid_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMGrid_2(Nd4jBackend backend) { if (!(Nd4j.getExecutioner() instanceof GridExecutioner)) return; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/LongTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/LongTests.java index a9b1d8da7..85eae255f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/LongTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/LongTests.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.specials; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -43,18 +44,15 @@ import static org.junit.jupiter.api.Assertions.assertNotEquals; @Slf4j @Disabled -@RunWith(Parameterized.class) -public class LongTests extends BaseNd4jTest { - DataType initialType; +public class LongTests extends BaseNd4jTestWithBackends { - public LongTests(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } + DataType initialType = Nd4j.dataType(); @Test - public void testSomething1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSomething1(Nd4jBackend backend) { // we create 2D array, total nr. of elements is 2.4B elements, > MAX_INT INDArray huge = Nd4j.create(8000000, 300); @@ -80,7 +78,9 @@ public class LongTests extends BaseNd4jTest { } @Test - public void testSomething2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSomething2(Nd4jBackend backend) { // we create 2D array, total nr. of elements is 2.4B elements, > MAX_INT INDArray huge = Nd4j.create(100, 10); @@ -106,7 +106,9 @@ public class LongTests extends BaseNd4jTest { } @Test - public void testLongTadOffsets1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLongTadOffsets1(Nd4jBackend backend) { INDArray huge = Nd4j.create(230000000, 10); Pair tad = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(huge, 1); @@ -115,7 +117,9 @@ public class LongTests extends BaseNd4jTest { } @Test - public void testLongTadOp1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLongTadOp1(Nd4jBackend backend) { double exp = Transforms.manhattanDistance(Nd4j.create(1000).assign(1.0), Nd4j.create(1000).assign(2.0)); @@ -133,7 +137,9 @@ public class LongTests extends BaseNd4jTest { } @Test - public void testLongTadOp2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLongTadOp2(Nd4jBackend backend) { INDArray hugeX = Nd4j.create(2300000, 1000).assign(1.0); hugeX.addiRowVector(Nd4j.create(1000).assign(2.0)); @@ -144,7 +150,9 @@ public class LongTests extends BaseNd4jTest { } @Test - public void testLongTadOp2_micro() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLongTadOp2_micro(Nd4jBackend backend) { INDArray hugeX = Nd4j.create(230, 1000).assign(1.0); hugeX.addiRowVector(Nd4j.create(1000).assign(2.0)); @@ -155,7 +163,9 @@ public class LongTests extends BaseNd4jTest { } @Test - public void testLongTadOp3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLongTadOp3(Nd4jBackend backend) { INDArray hugeX = Nd4j.create(2300000, 1000).assign(1.0); INDArray mean = hugeX.mean(1); @@ -166,7 +176,9 @@ public class LongTests extends BaseNd4jTest { } @Test - public void testLongTadOp4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLongTadOp4(Nd4jBackend backend) { INDArray hugeX = Nd4j.create(2300000, 1000).assign(1.0); INDArray mean = hugeX.argMax(1); @@ -177,7 +189,9 @@ public class LongTests extends BaseNd4jTest { } @Test - public void testLongTadOp5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLongTadOp5(Nd4jBackend backend) { List list = new ArrayList<>(); for (int i = 0; i < 2300000; i++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java index 23bdfb376..e59d81d6f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java @@ -26,9 +26,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.Assert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; @@ -37,23 +38,19 @@ import org.nd4j.nativeblas.NativeOpsHolder; @Slf4j -@RunWith(Parameterized.class) -public class RavelIndexTest extends BaseNd4jTest { - DataType initialType; +public class RavelIndexTest extends BaseNd4jTestWithBackends { + + DataType initialType = Nd4j.dataType(); - public RavelIndexTest(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } @BeforeEach - public void setUp() { + public void setUp(Nd4jBackend backend) { Nd4j.setDataType(DataType.FLOAT); } @AfterEach - public void setDown() { + public void setDown(Nd4jBackend backend) { Nd4j.setDataType(initialType); } @@ -64,60 +61,62 @@ public class RavelIndexTest extends BaseNd4jTest { @Test - public void ravelIndexesTest() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void ravelIndexesTest(Nd4jBackend backend) { // FIXME: we don't want this test running on cuda for now if (Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("cuda")) return; long[] multiIdxArray = new long[] { - 0,2,7, - 2,36,35, - 3,30,17, - 5,12,22, - 5,43,45, - 6,32,11, - 8,8,32, - 9,29,11, - 5,11,22, - 15,26,16, - 17,48,49, - 24,28,31, - 26,6,23, - 31,21,31, - 35,46,45, - 37,13,14, - 6,38,18, - 7,28,20, - 8,29,39, - 8,32,30, - 9,42,43, - 11,15,18, - 13,18,45, - 29,26,39, - 30,8,25, - 42,31,24, - 28,33,5, - 31,27,1, - 35,43,26, - 36,8,37, - 39,22,14, - 39,24,42, - 42,48,2, - 43,26,48, - 44,23,49, - 45,18,34, - 46,28,5, - 46,32,17, - 48,34,44, - 49,38,39, + 0,2,7, + 2,36,35, + 3,30,17, + 5,12,22, + 5,43,45, + 6,32,11, + 8,8,32, + 9,29,11, + 5,11,22, + 15,26,16, + 17,48,49, + 24,28,31, + 26,6,23, + 31,21,31, + 35,46,45, + 37,13,14, + 6,38,18, + 7,28,20, + 8,29,39, + 8,32,30, + 9,42,43, + 11,15,18, + 13,18,45, + 29,26,39, + 30,8,25, + 42,31,24, + 28,33,5, + 31,27,1, + 35,43,26, + 36,8,37, + 39,22,14, + 39,24,42, + 42,48,2, + 43,26,48, + 44,23,49, + 45,18,34, + 46,28,5, + 46,32,17, + 48,34,44, + 49,38,39, }; long[] flatIdxArray = new long[] { - 147, 10955, 14717, 21862, 24055, 27451, 34192, 39841, - 21792, 64836, 74809, 102791, 109643, 131701, 150265, 156324, - 27878, 31380, 35669, 35870, 40783, 47268, 55905, 123659, - 126585, 178594, 119915, 132091, 150036, 151797, 165354, 165522, - 179762, 182468, 186459, 190294, 195165, 195457, 204024, 208499 + 147, 10955, 14717, 21862, 24055, 27451, 34192, 39841, + 21792, 64836, 74809, 102791, 109643, 131701, 150265, 156324, + 27878, 31380, 35669, 35870, 40783, 47268, 55905, 123659, + 126585, 178594, 119915, 132091, 150036, 151797, 165354, 165522, + 179762, 182468, 186459, 190294, 195165, 195457, 204024, 208499 }; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java index 1941ac47a..eef65331a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java @@ -20,22 +20,20 @@ package org.nd4j.linalg.specials; -import com.google.common.primitives.Doubles; -import com.google.common.primitives.Floats; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.LongPointer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; -import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.nativeblas.NativeOpsHolder; @@ -46,30 +44,28 @@ import java.util.stream.LongStream; import static org.junit.jupiter.api.Assertions.assertArrayEquals; @Slf4j -@RunWith(Parameterized.class) -public class SortCooTests extends BaseNd4jTest { - DataType initialType; - DataType initialDefaultType; +public class SortCooTests extends BaseNd4jTestWithBackends { + + DataType initialType = Nd4j.dataType(); + DataType initialDefaultType = Nd4j.defaultFloatingPointType(); + - public SortCooTests(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - this.initialDefaultType = Nd4j.defaultFloatingPointType(); - } @BeforeEach - public void setUp() { + public void setUp(Nd4jBackend backend) { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); } @AfterEach - public void setDown() { + public void setDown(Nd4jBackend backend) { Nd4j.setDefaultDataTypes(initialType, Nd4j.defaultFloatingPointType()); } @Test - public void sortSparseCooIndicesSort1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void sortSparseCooIndicesSort1(Nd4jBackend backend) { // FIXME: we don't want this test running on cuda for now if (Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("cuda")) return; @@ -103,7 +99,9 @@ public class SortCooTests extends BaseNd4jTest { } @Test - public void sortSparseCooIndicesSort2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void sortSparseCooIndicesSort2(Nd4jBackend backend) { // FIXME: we don't want this test running on cuda for now if (Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("cuda")) return; @@ -150,7 +148,9 @@ public class SortCooTests extends BaseNd4jTest { } @Test - public void sortSparseCooIndicesSort3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void sortSparseCooIndicesSort3(Nd4jBackend backend) { // FIXME: we don't want this test running on cuda for now if (Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("cuda")) return; @@ -188,7 +188,9 @@ public class SortCooTests extends BaseNd4jTest { } @Test - public void sortSparseCooIndicesSort4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void sortSparseCooIndicesSort4(Nd4jBackend backend) { // FIXME: we don't want this test running on cuda for now if (Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("cuda")) return; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/DataSetUtilsTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/DataSetUtilsTest.java index b77633083..04569daf0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/DataSetUtilsTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/DataSetUtilsTest.java @@ -26,7 +26,9 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -38,11 +40,8 @@ import java.nio.file.Path; import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j -public class DataSetUtilsTest extends BaseNd4jTest { +public class DataSetUtilsTest extends BaseNd4jTestWithBackends { - public DataSetUtilsTest(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -55,7 +54,9 @@ public class DataSetUtilsTest extends BaseNd4jTest { private SIS sis; // @Test - public void testAll(@TempDir Path tmpFld) { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAll(@TempDir Path tmpFld,Nd4jBackend backend) { // sis = new SIS(); // diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java index 7e784853f..be46fa226 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/NDArrayUtilTest.java @@ -21,10 +21,11 @@ package org.nd4j.linalg.util; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.common.util.ArrayUtil; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4jBackend; @@ -34,22 +35,23 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** * @author Hamdi Douss */ -@RunWith(Parameterized.class) -public class NDArrayUtilTest extends BaseNd4jTest { - public NDArrayUtilTest(Nd4jBackend backend) { - super(backend); - } +public class NDArrayUtilTest extends BaseNd4jTestWithBackends { + @Test - public void testMatrixConversion() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMatrixConversion(Nd4jBackend backend) { int[][] nums = {{1, 2}, {3, 4}, {5, 6}}; INDArray result = NDArrayUtil.toNDArray(nums); assertArrayEquals(new long[]{2,3}, result.shape()); } @Test - public void testVectorConversion() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorConversion(Nd4jBackend backend) { int[] nums = {1, 2, 3, 4}; INDArray result = NDArrayUtil.toNDArray(nums); assertArrayEquals(new long[]{1, 4}, result.shape()); @@ -57,7 +59,9 @@ public class NDArrayUtilTest extends BaseNd4jTest { @Test - public void testFlattenArray1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFlattenArray1(Nd4jBackend backend) { float[][][] arrX = new float[2][2][2]; float[] arrZ = ArrayUtil.flatten(arrX); @@ -66,7 +70,9 @@ public class NDArrayUtilTest extends BaseNd4jTest { } @Test - public void testFlattenArray2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFlattenArray2(Nd4jBackend backend) { float[][][] arrX = new float[5][4][3]; float[] arrZ = ArrayUtil.flatten(arrX); @@ -76,7 +82,9 @@ public class NDArrayUtilTest extends BaseNd4jTest { @Test - public void testFlattenArray3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFlattenArray3(Nd4jBackend backend) { float[][][] arrX = new float[5][2][3]; float[] arrZ = ArrayUtil.flatten(arrX); @@ -85,7 +93,9 @@ public class NDArrayUtilTest extends BaseNd4jTest { } @Test - public void testFlattenArray4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFlattenArray4(Nd4jBackend backend) { float[][][][] arrX = new float[5][2][3][3]; float[] arrZ = ArrayUtil.flatten(arrX); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/PreconditionsTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/PreconditionsTest.java index b6f06e668..0922cb9e2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/PreconditionsTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/PreconditionsTest.java @@ -21,8 +21,10 @@ package org.nd4j.linalg.util; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.common.base.Preconditions; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -33,14 +35,12 @@ import java.util.Arrays; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; -public class PreconditionsTest extends BaseNd4jTest { - - public PreconditionsTest(Nd4jBackend backend) { - super(backend); - } +public class PreconditionsTest extends BaseNd4jTestWithBackends { @Test - public void test(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void test(Nd4jBackend backend){ INDArray arr = Nd4j.linspace(1,60,60).reshape('c',3,4,5); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTest.java index 6f1d088be..6162e05e5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTest.java @@ -21,9 +21,10 @@ package org.nd4j.linalg.util; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; @@ -34,16 +35,15 @@ import static org.junit.jupiter.api.Assertions.*; /** * @author Adam Gibson */ -@RunWith(Parameterized.class) -public class ShapeTest extends BaseNd4jTest { - public ShapeTest(Nd4jBackend backend) { - super(backend); - } +public class ShapeTest extends BaseNd4jTestWithBackends { + @Test - public void testToOffsetZero() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToOffsetZero(Nd4jBackend backend) { INDArray matrix = Nd4j.rand(3, 5); INDArray rowOne = matrix.getRow(1); INDArray row1Copy = Shape.toOffsetZero(rowOne); @@ -63,7 +63,9 @@ public class ShapeTest extends BaseNd4jTest { @Test - public void testDupLeadingTrailingZeros() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDupLeadingTrailingZeros(Nd4jBackend backend) { testDupHelper(1, 10); testDupHelper(10, 1); testDupHelper(1, 10, 1); @@ -84,7 +86,9 @@ public class ShapeTest extends BaseNd4jTest { } @Test - public void testLeadingOnes() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLeadingOnes(Nd4jBackend backend) { INDArray arr = Nd4j.create(1, 5, 5); assertEquals(1, arr.getLeadingOnes()); INDArray arr2 = Nd4j.create(2, 2); @@ -94,7 +98,9 @@ public class ShapeTest extends BaseNd4jTest { } @Test - public void testTrailingOnes() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTrailingOnes(Nd4jBackend backend) { INDArray arr2 = Nd4j.create(5, 5, 1); assertEquals(1, arr2.getTrailingOnes()); INDArray arr4 = Nd4j.create(5, 5, 1, 1); @@ -102,7 +108,9 @@ public class ShapeTest extends BaseNd4jTest { } @Test - public void testElementWiseCompareOnesInMiddle() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testElementWiseCompareOnesInMiddle(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3); INDArray onesInMiddle = Nd4j.linspace(1, 6, 6).reshape(2, 1, 3); for (int i = 0; i < arr.length(); i++) { @@ -114,7 +122,9 @@ public class ShapeTest extends BaseNd4jTest { @Test - public void testSumLeadingTrailingZeros() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSumLeadingTrailingZeros(Nd4jBackend backend) { testSumHelper(1, 5, 5); testSumHelper(5, 5, 1); testSumHelper(1, 5, 1); @@ -144,6 +154,8 @@ public class ShapeTest extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testEqualsWithSqueeze(){ assertTrue(Shape.shapeEqualWithSqueeze(null, null)); @@ -165,6 +177,8 @@ public class ShapeTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testShapeOrder(){ long[] shape = {2,2}; long[] stride = {1,8}; //Ascending strides -> F order diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java index 419b8d015..67435acf1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java @@ -23,9 +23,10 @@ package org.nd4j.linalg.util; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.shape.Tile; @@ -40,16 +41,15 @@ import static org.junit.jupiter.api.Assertions.*; * @author Adam Gibson */ @Slf4j -@RunWith(Parameterized.class) -public class ShapeTestC extends BaseNd4jTest { - public ShapeTestC(Nd4jBackend backend) { - super(backend); - } +public class ShapeTestC extends BaseNd4jTestWithBackends { + @Test - public void testToOffsetZero() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToOffsetZero(Nd4jBackend backend) { INDArray matrix = Nd4j.rand(3, 5); INDArray rowOne = matrix.getRow(1); INDArray row1Copy = Shape.toOffsetZero(rowOne); @@ -68,7 +68,9 @@ public class ShapeTestC extends BaseNd4jTest { @Test - public void testTile() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testTile(Nd4jBackend backend) { INDArray arr = Nd4j.scalar(DataType.DOUBLE, 1.0).reshape(1, 1); //INDArray[] inputs, INDArray[] outputs, int[] axis INDArray result = Nd4j.createUninitialized(DataType.DOUBLE, 2,2); @@ -80,7 +82,9 @@ public class ShapeTestC extends BaseNd4jTest { } @Test - public void testElementWiseCompareOnesInMiddle() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testElementWiseCompareOnesInMiddle(Nd4jBackend backend) { INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3); INDArray onesInMiddle = Nd4j.linspace(1, 6, 6).reshape(2, 1, 3); for (int i = 0; i < arr.length(); i++) @@ -89,7 +93,9 @@ public class ShapeTestC extends BaseNd4jTest { @Test - public void testKeepDimsShape_1_T() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testKeepDimsShape_1_T(Nd4jBackend backend) { val shape = new int[]{5, 5}; val axis = new int[]{1, 0, 1}; @@ -99,7 +105,9 @@ public class ShapeTestC extends BaseNd4jTest { } @Test - public void testKeepDimsShape_1_F() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testKeepDimsShape_1_F(Nd4jBackend backend) { val shape = new int[]{5, 5}; val axis = new int[]{0, 0, 1}; @@ -109,7 +117,9 @@ public class ShapeTestC extends BaseNd4jTest { } @Test - public void testKeepDimsShape_2_T() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testKeepDimsShape_2_T(Nd4jBackend backend) { val shape = new int[]{5, 5, 5}; val axis = new int[]{1, 0, 1}; @@ -119,7 +129,9 @@ public class ShapeTestC extends BaseNd4jTest { } @Test - public void testKeepDimsShape_2_F() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testKeepDimsShape_2_F(Nd4jBackend backend) { val shape = new int[]{5, 5, 5}; val axis = new int[]{0, 0, 1}; @@ -130,7 +142,9 @@ public class ShapeTestC extends BaseNd4jTest { @Test - public void testKeepDimsShape_3_T() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testKeepDimsShape_3_T(Nd4jBackend backend) { val shape = new int[]{1, 1}; val axis = new int[]{1, 0, 1}; @@ -140,7 +154,9 @@ public class ShapeTestC extends BaseNd4jTest { } @Test - public void testKeepDimsShape_3_F() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testKeepDimsShape_3_F(Nd4jBackend backend) { val shape = new int[]{1, 1}; val axis = new int[]{0, 0}; @@ -153,7 +169,9 @@ public class ShapeTestC extends BaseNd4jTest { @Test - public void testKeepDimsShape_4_F() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testKeepDimsShape_4_F(Nd4jBackend backend) { val shape = new int[]{4, 4}; val axis = new int[]{0, 0}; @@ -166,7 +184,9 @@ public class ShapeTestC extends BaseNd4jTest { @Test - public void testAxisNormalization_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAxisNormalization_1(Nd4jBackend backend) { val axis = new int[] {1, -2}; val rank = 2; val exp = new int[] {0, 1}; @@ -176,7 +196,9 @@ public class ShapeTestC extends BaseNd4jTest { } @Test - public void testAxisNormalization_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAxisNormalization_2(Nd4jBackend backend) { val axis = new int[] {1, -2, 0}; val rank = 2; val exp = new int[] {0, 1}; @@ -186,20 +208,22 @@ public class ShapeTestC extends BaseNd4jTest { } @Test() - public void testAxisNormalization_3() { - assertThrows(ND4JIllegalStateException.class,() -> { - val axis = new int[] {1, -2, 2}; - val rank = 2; - val exp = new int[] {0, 1}; + public void testAxisNormalization_3(Nd4jBackend backend) { + assertThrows(ND4JIllegalStateException.class,() -> { + val axis = new int[] {1, -2, 2}; + val rank = 2; + val exp = new int[] {0, 1}; - val norm = Shape.normalizeAxis(rank, axis); - assertArrayEquals(exp, norm); - }); + val norm = Shape.normalizeAxis(rank, axis); + assertArrayEquals(exp, norm); + }); } @Test - public void testAxisNormalization_4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAxisNormalization_4(Nd4jBackend backend) { val axis = new int[] {1, 2, 0}; val rank = 3; val exp = new int[] {0, 1, 2}; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestArrayUtils.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestArrayUtils.java index 64e235af5..4bc48ab11 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestArrayUtils.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestArrayUtils.java @@ -21,22 +21,23 @@ package org.nd4j.linalg.util; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.common.util.ArrayUtil; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.factory.Nd4jBackend; import java.util.Random; import static org.junit.jupiter.api.Assertions.*; -public class TestArrayUtils extends BaseNd4jTest { +public class TestArrayUtils extends BaseNd4jTestWithBackends { - public TestArrayUtils(Nd4jBackend backend) { - super(backend); - } @Test - public void testFlattenDoubleArray() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFlattenDoubleArray(Nd4jBackend backend) { assertArrayEquals(new double[0], ArrayUtil.flattenDoubleArray(new double[0]), 0.0); Random r = new Random(12345L); @@ -84,7 +85,9 @@ public class TestArrayUtils extends BaseNd4jTest { } @Test - public void testFlattenFloatArray() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFlattenFloatArray(Nd4jBackend backend) { assertArrayEquals(new float[0], ArrayUtil.flattenFloatArray(new float[0]), 0.0f); Random r = new Random(12345L); @@ -132,7 +135,9 @@ public class TestArrayUtils extends BaseNd4jTest { } @Test - public void testArrayShape() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArrayShape(Nd4jBackend backend) { assertArrayEquals(ArrayUtil.arrayShape(new int[0]), new int[] {0}); assertArrayEquals(ArrayUtil.arrayShape(new int[5][7][9]), new int[] {5, 7, 9}); assertArrayEquals(ArrayUtil.arrayShape(new Object[2][3][4][5][6]), new int[] {2, 3, 4, 5, 6}); @@ -143,7 +148,9 @@ public class TestArrayUtils extends BaseNd4jTest { } @Test - public void testArgMinOfMaxMethods() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testArgMinOfMaxMethods(Nd4jBackend backend) { int[] first = {1, 5, 2, 4}; int[] second = {4, 6, 3, 2}; @@ -154,7 +161,9 @@ public class TestArrayUtils extends BaseNd4jTest { } @Test - public void testAssertNotRagged(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAssertNotRagged(Nd4jBackend backend){ //Rank 1 - should be fine ArrayUtil.assertNotRagged(new Object[0]); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestCollections.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestCollections.java index 1d153d512..9a8334527 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestCollections.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/TestCollections.java @@ -21,7 +21,9 @@ package org.nd4j.linalg.util; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.common.collection.CompactHeapStringList; import org.nd4j.linalg.factory.Nd4jBackend; @@ -30,14 +32,12 @@ import java.util.*; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -public class TestCollections extends BaseNd4jTest { - - public TestCollections(Nd4jBackend backend) { - super(backend); - } +public class TestCollections extends BaseNd4jTestWithBackends { @Test - public void testCompactHeapStringList() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCompactHeapStringList(Nd4jBackend backend) { int[] reallocSizeBytes = new int[] {1024, 1048576}; int[] intReallocSizeBytes = new int[] {1024, 1048576}; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java index 01a363b10..cd19f1793 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java @@ -26,9 +26,11 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -47,15 +49,12 @@ import java.util.zip.ZipOutputStream; import static org.junit.jupiter.api.Assertions.*; -public class ValidationUtilTests extends BaseNd4jTest { - - - public ValidationUtilTests(Nd4jBackend backend) { - super(backend); - } +public class ValidationUtilTests extends BaseNd4jTestWithBackends { @Test - public void testFileValidation(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testFileValidation(@TempDir Path testDir,Nd4jBackend backend) throws Exception { File f = testDir.toFile(); //Test not existent file: @@ -91,7 +90,9 @@ public class ValidationUtilTests extends BaseNd4jTest { } @Test - public void testZipValidation(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testZipValidation(@TempDir Path testDir,Nd4jBackend backend) throws Exception { File f = testDir.toFile(); //Test not existent file: @@ -141,7 +142,9 @@ public class ValidationUtilTests extends BaseNd4jTest { @Test - public void testINDArrayTextValidation(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testINDArrayTextValidation(@TempDir Path testDir,Nd4jBackend backend) throws Exception { File f = testDir.toFile(); //Test not existent file: @@ -282,7 +285,9 @@ public class ValidationUtilTests extends BaseNd4jTest { } @Test - public void testNpzValidation(@TempDir Path testDIr) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNpzValidation(@TempDir Path testDIr,Nd4jBackend backend) throws Exception { File f = testDIr.toFile(); @@ -351,7 +356,9 @@ public class ValidationUtilTests extends BaseNd4jTest { } @Test - public void testNumpyTxtValidation(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNumpyTxtValidation(@TempDir Path testDir,Nd4jBackend backend) throws Exception { File f = testDir.toFile(); //Test not existent file: @@ -419,7 +426,9 @@ public class ValidationUtilTests extends BaseNd4jTest { } @Test - public void testValidateSameDiff(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testValidateSameDiff(@TempDir Path testDir,Nd4jBackend backend) throws Exception { Nd4j.setDataType(DataType.FLOAT); File f = testDir.toFile(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java index f70c753fc..3adc87262 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java @@ -26,9 +26,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -52,9 +53,9 @@ import static org.junit.jupiter.api.Assertions.*; import static org.nd4j.linalg.api.buffer.DataType.DOUBLE; @Slf4j -@RunWith(Parameterized.class) -public class BasicWorkspaceTests extends BaseNd4jTest { - DataType initialType; + +public class BasicWorkspaceTests extends BaseNd4jTestWithBackends { + DataType initialType = Nd4j.dataType(); private static final WorkspaceConfiguration basicConfig = WorkspaceConfiguration.builder() .initialSize(10 * 1024 * 1024).maxSize(10 * 1024 * 1024).overallocationLimit(0.1) @@ -72,10 +73,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.FIRST_LOOP) .policyMirroring(MirroringPolicy.FULL).policySpill(SpillPolicy.EXTERNAL).build(); - public BasicWorkspaceTests(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } + @BeforeEach public void setUp() { @@ -91,7 +89,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testCold() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCold(Nd4jBackend backend) { INDArray array = Nd4j.create(10); array.addi(1.0); @@ -100,7 +100,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testMinSize1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMinSize1(Nd4jBackend backend) { WorkspaceConfiguration conf = WorkspaceConfiguration.builder().minSize(10 * 1024 * 1024) .overallocationLimit(1.0).policyAllocation(AllocationPolicy.OVERALLOCATE) .policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL) @@ -120,7 +122,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testBreakout2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBreakout2(Nd4jBackend backend) { assertEquals(null, Nd4j.getMemoryManager().getCurrentWorkspace()); @@ -132,7 +136,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testBreakout1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBreakout1(Nd4jBackend backend) { assertEquals(null, Nd4j.getMemoryManager().getCurrentWorkspace()); @@ -162,7 +168,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testLeverage3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLeverage3(Nd4jBackend backend) { try (Nd4jWorkspace wsOne = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "EXT")) { INDArray array = null; @@ -183,7 +191,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { @Test - public void testLeverageTo2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLeverageTo2(Nd4jBackend backend) { val exp = Nd4j.scalar(15.0); try (Nd4jWorkspace wsOne = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(loopOverTimeConfig, "EXT")) { @@ -217,7 +227,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testLeverageTo1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLeverageTo1(Nd4jBackend backend) { try (Nd4jWorkspace wsOne = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "EXT")) { INDArray array1 = Nd4j.create(DOUBLE, 5); @@ -237,7 +249,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testOutOfScope1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOutOfScope1(Nd4jBackend backend) { try (Nd4jWorkspace wsOne = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "EXT")) { INDArray array1 = Nd4j.create(new double[] {1f, 2f, 3f, 4f, 5f}); @@ -267,7 +281,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testLeverage1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLeverage1(Nd4jBackend backend) { try (Nd4jWorkspace wsOne = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "EXT")) { @@ -298,7 +314,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testNoShape1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNoShape1(Nd4jBackend backend) { int outDepth = 50; int miniBatch = 64; int outH = 8; @@ -319,7 +337,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testCreateDetached1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCreateDetached1(Nd4jBackend backend) { try (Nd4jWorkspace wsI = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "ITER")) { @@ -342,7 +362,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { @Test - public void testDetach1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDetach1(Nd4jBackend backend) { INDArray array = null; INDArray copy = null; try (Nd4jWorkspace wsI = @@ -372,7 +394,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testScope2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScope2(Nd4jBackend backend) { INDArray array = null; try (Nd4jWorkspace wsI = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(loopFirstConfig, "ITER")) { @@ -396,7 +420,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testScope1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScope1(Nd4jBackend backend) { INDArray array = null; try (Nd4jWorkspace wsI = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "ITER")) { @@ -409,7 +435,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testIsAttached3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsAttached3(Nd4jBackend backend) { INDArray array = Nd4j.create(DOUBLE, 100); try (Nd4jWorkspace wsI = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "ITER")) { @@ -427,7 +455,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testIsAttached2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsAttached2(Nd4jBackend backend) { INDArray array = Nd4j.create(DOUBLE, 100); try (Nd4jWorkspace wsI = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(loopFirstConfig, "ITER")) { @@ -444,7 +474,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testIsAttached1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testIsAttached1(Nd4jBackend backend) { try (Nd4jWorkspace wsI = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(loopFirstConfig, "ITER")) { @@ -459,7 +491,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testOverallocation3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOverallocation3(Nd4jBackend backend) { WorkspaceConfiguration overallocationConfig = WorkspaceConfiguration.builder().initialSize(0) .maxSize(10 * 1024 * 1024).overallocationLimit(1.0) .policyAllocation(AllocationPolicy.OVERALLOCATE).policyLearning(LearningPolicy.OVER_TIME) @@ -487,7 +521,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testOverallocation2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOverallocation2(Nd4jBackend backend) { WorkspaceConfiguration overallocationConfig = WorkspaceConfiguration.builder().initialSize(0) .maxSize(10 * 1024 * 1024).overallocationLimit(1.0) .policyAllocation(AllocationPolicy.OVERALLOCATE).policyLearning(LearningPolicy.FIRST_LOOP) @@ -508,7 +544,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testOverallocation1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testOverallocation1(Nd4jBackend backend) { WorkspaceConfiguration overallocationConfig = WorkspaceConfiguration.builder().initialSize(1024) .maxSize(10 * 1024 * 1024).overallocationLimit(1.0) .policyAllocation(AllocationPolicy.OVERALLOCATE).policyLearning(LearningPolicy.NONE) @@ -520,7 +558,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testToggle1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToggle1(Nd4jBackend backend) { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().createNewWorkspace(loopFirstConfig); Nd4j.getMemoryManager().setCurrentWorkspace(workspace); @@ -574,7 +614,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testLoop4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLoop4(Nd4jBackend backend) { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().createNewWorkspace(loopFirstConfig); Nd4j.getMemoryManager().setCurrentWorkspace(workspace); @@ -601,7 +643,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testLoops3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLoops3(Nd4jBackend backend) { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().createNewWorkspace(loopFirstConfig); Nd4j.getMemoryManager().setCurrentWorkspace(workspace); @@ -628,7 +672,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testLoops2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLoops2(Nd4jBackend backend) { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().createNewWorkspace(loopOverTimeConfig); Nd4j.getMemoryManager().setCurrentWorkspace(workspace); @@ -666,7 +712,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testLoops1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testLoops1(Nd4jBackend backend) { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().createNewWorkspace(loopOverTimeConfig); Nd4j.getMemoryManager().setCurrentWorkspace(workspace); @@ -721,7 +769,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testAllocation6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllocation6(Nd4jBackend backend) { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "testAllocation6"); Nd4j.getMemoryManager().setCurrentWorkspace(workspace); @@ -745,7 +795,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testAllocation5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllocation5(Nd4jBackend backend) { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "testAllocation5"); Nd4j.getMemoryManager().setCurrentWorkspace(workspace); @@ -773,7 +825,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { @Test - public void testAllocation4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllocation4(Nd4jBackend backend) { WorkspaceConfiguration failConfig = WorkspaceConfiguration.builder().initialSize(1024 * 1024) .maxSize(1024 * 1024).overallocationLimit(0.1).policyAllocation(AllocationPolicy.STRICT) .policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL) @@ -809,7 +863,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testAllocation3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllocation3(Nd4jBackend backend) { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "testAllocation2"); @@ -833,7 +889,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testAllocation2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllocation2(Nd4jBackend backend) { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "testAllocation2"); @@ -857,7 +915,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testAllocation1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAllocation1(Nd4jBackend backend) { @@ -929,7 +989,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { @Test - public void testMmap1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmap1(Nd4jBackend backend) { // we don't support MMAP on cuda yet if (Nd4j.getExecutioner().getClass().getName().toLowerCase().contains("cuda")) return; @@ -961,7 +1023,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { @Test @Disabled - public void testMmap2() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMmap2(Nd4jBackend backend) throws Exception { // we don't support MMAP on cuda yet if (Nd4j.getExecutioner().getClass().getName().toLowerCase().contains("cuda")) return; @@ -987,7 +1051,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { @Test - public void testInvalidLeverageMigrateDetach(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testInvalidLeverageMigrateDetach(Nd4jBackend backend){ try { MemoryWorkspace ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(basicConfig, "testInvalidLeverage"); @@ -1093,7 +1159,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testBadGenerationLeverageMigrateDetach(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBadGenerationLeverageMigrateDetach(Nd4jBackend backend){ INDArray gen2 = null; for (int i = 0; i < 4; i++) { @@ -1198,7 +1266,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testDtypeLeverage(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDtypeLeverage(Nd4jBackend backend){ for(DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { for (DataType arrayDType : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { @@ -1227,7 +1297,9 @@ public class BasicWorkspaceTests extends BaseNd4jTest { } @Test - public void testCircularWorkspaceAsymmetry_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCircularWorkspaceAsymmetry_1(Nd4jBackend backend) { // nothing to test on CPU here if (Nd4j.getEnvironment().isCPU()) return; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CudaWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CudaWorkspaceTests.java index c10115122..aac547e9d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CudaWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CudaWorkspaceTests.java @@ -23,31 +23,29 @@ package org.nd4j.linalg.workspace; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.enums.MirroringPolicy; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j -@RunWith(Parameterized.class) -public class CudaWorkspaceTests extends BaseNd4jTest { - private DataType initialType; - public CudaWorkspaceTests(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } +public class CudaWorkspaceTests extends BaseNd4jTestWithBackends { + private DataType initialType = Nd4j.dataType(); + @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWorkspaceReuse() { if (Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.CUDA) return; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CyclicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CyclicWorkspaceTests.java index 3aaf5b23b..9f1cb93ba 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CyclicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/CyclicWorkspaceTests.java @@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.enums.AllocationPolicy; import org.nd4j.linalg.api.memory.enums.LearningPolicy; @@ -36,14 +37,13 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @Slf4j -@RunWith(Parameterized.class) -public class CyclicWorkspaceTests extends BaseNd4jTest { - public CyclicWorkspaceTests(Nd4jBackend backend) { - super(backend); - } + +public class CyclicWorkspaceTests extends BaseNd4jTestWithBackends { @Test - public void testBasicMechanics_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBasicMechanics_1(Nd4jBackend backend) { val fShape = new long[]{128, 784}; val lShape = new long[] {128, 10}; val prefetchSize = 24; @@ -64,7 +64,9 @@ public class CyclicWorkspaceTests extends BaseNd4jTest { @Test @Disabled - public void testGc() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testGc(Nd4jBackend backend) { val indArray = Nd4j.create(4, 4); indArray.putRow(0, Nd4j.create(new float[]{0, 2, -2, 0})); indArray.putRow(1, Nd4j.create(new float[]{0, 1, -1, 0})); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java index 2b18ead2d..a990069ce 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java @@ -25,9 +25,10 @@ import lombok.val; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.enums.AllocationPolicy; @@ -42,14 +43,11 @@ import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace; import static org.junit.jupiter.api.Assertions.*; @Slf4j -@RunWith(Parameterized.class) -public class DebugModeTests extends BaseNd4jTest { - DataType initialType; - public DebugModeTests(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } +public class DebugModeTests extends BaseNd4jTestWithBackends { + DataType initialType = Nd4j.dataType(); + + @BeforeEach public void turnMeUp() { @@ -69,7 +67,9 @@ public class DebugModeTests extends BaseNd4jTest { } @Test - public void testDebugMode_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDebugMode_1(Nd4jBackend backend) { assertEquals(DebugMode.DISABLED, Nd4j.getWorkspaceManager().getDebugMode()); Nd4j.getWorkspaceManager().setDebugMode(DebugMode.SPILL_EVERYTHING); @@ -78,7 +78,9 @@ public class DebugModeTests extends BaseNd4jTest { } @Test - public void testSpillMode_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSpillMode_1(Nd4jBackend backend) { Nd4j.getWorkspaceManager().setDebugMode(DebugMode.SPILL_EVERYTHING); val basicConfig = WorkspaceConfiguration.builder() @@ -104,7 +106,9 @@ public class DebugModeTests extends BaseNd4jTest { } @Test - public void testSpillMode_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSpillMode_2(Nd4jBackend backend) { Nd4j.getWorkspaceManager().setDebugMode(DebugMode.SPILL_EVERYTHING); val basicConfig = WorkspaceConfiguration.builder() @@ -138,7 +142,9 @@ public class DebugModeTests extends BaseNd4jTest { } @Test - public void testBypassMode_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBypassMode_1(Nd4jBackend backend) { Nd4j.getWorkspaceManager().setDebugMode(DebugMode.BYPASS_EVERYTHING); val basicConfig = WorkspaceConfiguration.builder() diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/EndlessWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/EndlessWorkspaceTests.java index cce4562ca..c65c28e43 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/EndlessWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/EndlessWorkspaceTests.java @@ -27,9 +27,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; @@ -49,14 +50,9 @@ import static org.junit.jupiter.api.Assertions.assertEquals; @Disabled @Slf4j -@RunWith(Parameterized.class) -public class EndlessWorkspaceTests extends BaseNd4jTest { - DataType initialType; - public EndlessWorkspaceTests(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } +public class EndlessWorkspaceTests extends BaseNd4jTestWithBackends { + DataType initialType = Nd4j.dataType(); @BeforeEach public void startUp() { @@ -77,7 +73,9 @@ public class EndlessWorkspaceTests extends BaseNd4jTest { * @throws Exception */ @Test - public void endlessTest1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void endlessTest1(Nd4jBackend backend) { Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration( WorkspaceConfiguration.builder().initialSize(100 * 1024L * 1024L).build()); @@ -104,7 +102,9 @@ public class EndlessWorkspaceTests extends BaseNd4jTest { * @throws Exception */ @Test - public void endlessTest2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void endlessTest2(Nd4jBackend backend) { Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration( WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L).build()); @@ -138,7 +138,9 @@ public class EndlessWorkspaceTests extends BaseNd4jTest { * @throws Exception */ @Test - public void endlessTest3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void endlessTest3(Nd4jBackend backend) { Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration( WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L).build()); @@ -167,7 +169,9 @@ public class EndlessWorkspaceTests extends BaseNd4jTest { } @Test - public void endlessTest4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void endlessTest4(Nd4jBackend backend) { Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration( WorkspaceConfiguration.builder().initialSize(100 * 1024L * 1024L).build()); while (true) { @@ -188,7 +192,9 @@ public class EndlessWorkspaceTests extends BaseNd4jTest { } @Test - public void endlessTest5() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void endlessTest5(Nd4jBackend backend) throws Exception { while (true) { Thread thread = new Thread(new Runnable() { @Override @@ -210,7 +216,9 @@ public class EndlessWorkspaceTests extends BaseNd4jTest { } @Test - public void endlessTest6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void endlessTest6(Nd4jBackend backend) { Nd4j.getMemoryManager().togglePeriodicGc(false); WorkspaceConfiguration wsConf = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) .policyLearning(LearningPolicy.NONE).build(); @@ -227,7 +235,10 @@ public class EndlessWorkspaceTests extends BaseNd4jTest { } @Test - public void endlessValidation1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + + public void endlessValidation1(Nd4jBackend backend) { Nd4j.getMemoryManager().togglePeriodicGc(true); AtomicLong counter = new AtomicLong(0); @@ -246,7 +257,9 @@ public class EndlessWorkspaceTests extends BaseNd4jTest { @Test - public void testPerf1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPerf1(Nd4jBackend backend) { Nd4j.getWorkspaceManager() .setDefaultWorkspaceConfiguration(WorkspaceConfiguration.builder().initialSize(50000L).build()); @@ -287,7 +300,9 @@ public class EndlessWorkspaceTests extends BaseNd4jTest { } @Test - public void endlessTestSerDe1() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void endlessTestSerDe1(Nd4jBackend backend) throws Exception { INDArray features = Nd4j.create(32, 3, 224, 224); INDArray labels = Nd4j.create(32, 200); File tmp = File.createTempFile("12dadsad", "dsdasds"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java index 1abb24014..1df9d4af7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java @@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace; @@ -48,24 +49,21 @@ import java.util.Arrays; import static org.junit.jupiter.api.Assertions.*; @Slf4j -@RunWith(Parameterized.class) -public class SpecialWorkspaceTests extends BaseNd4jTest { - private DataType initialType; - public SpecialWorkspaceTests(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } +public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends { + private DataType initialType = Nd4j.dataType(); @AfterEach - public void shutUp() { + public void shutUp(Nd4jBackend backend) { Nd4j.getMemoryManager().setCurrentWorkspace(null); Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); Nd4j.setDataType(this.initialType); } @Test - public void testVariableTimeSeries1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariableTimeSeries1(Nd4jBackend backend) { WorkspaceConfiguration configuration = WorkspaceConfiguration .builder() .initialSize(0) @@ -172,13 +170,15 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { } @Test - public void testVariableTimeSeries2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariableTimeSeries2(Nd4jBackend backend) { WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(3.0) - .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE) - .policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.ENDOFBUFFER_REACHED).build(); + .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE) + .policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.ENDOFBUFFER_REACHED).build(); Nd4jWorkspace workspace = - (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "WS1"); + (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "WS1"); // workspace.enableDebug(true); try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { @@ -213,7 +213,9 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { } @Test - public void testViewDetach_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testViewDetach_1(Nd4jBackend backend) { WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(10000000).overallocationLimit(3.0) .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE) .policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.BLOCK_LEFT).build(); @@ -242,7 +244,9 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { } @Test - public void testAlignment_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAlignment_1(Nd4jBackend backend) { WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "WS132143452343"); @@ -263,7 +267,9 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { } @Test - public void testNoOpExecution_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNoOpExecution_1(Nd4jBackend backend) { val configuration = WorkspaceConfiguration.builder().initialSize(10000000).overallocationLimit(3.0) .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE) .policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.BLOCK_LEFT).build(); @@ -300,6 +306,8 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWorkspaceOrder_1(){ WorkspaceConfiguration conf = WorkspaceConfiguration.builder() .initialSize(1_000_000) @@ -335,6 +343,8 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMmapedWorkspaceLimits_1() throws Exception { if (!Nd4j.getEnvironment().isCPU()) return; @@ -359,6 +369,8 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMmapedWorkspace_Path_Limits_1() throws Exception { if (!Nd4j.getEnvironment().isCPU()) return; @@ -383,6 +395,8 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testDeleteMappedFile_1() throws Exception { if (!Nd4j.getEnvironment().isCPU()) return; @@ -406,29 +420,31 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { @Test() public void testDeleteMappedFile_2() throws Exception { - assertThrows(IllegalArgumentException.class,() -> { - if (!Nd4j.getEnvironment().isCPU()) - throw new IllegalArgumentException("Don't try to run on CUDA"); + assertThrows(IllegalArgumentException.class,() -> { + if (!Nd4j.getEnvironment().isCPU()) + throw new IllegalArgumentException("Don't try to run on CUDA"); - val tmpFile = Files.createTempFile("some", "file"); - val mmap = WorkspaceConfiguration.builder() - .initialSize(200 * 1024L * 1024L) // 200mbs - .tempFilePath(tmpFile.toAbsolutePath().toString()) - .policyLocation(LocationPolicy.MMAP) - .build(); + val tmpFile = Files.createTempFile("some", "file"); + val mmap = WorkspaceConfiguration.builder() + .initialSize(200 * 1024L * 1024L) // 200mbs + .tempFilePath(tmpFile.toAbsolutePath().toString()) + .policyLocation(LocationPolicy.MMAP) + .build(); - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) { - val x = Nd4j.rand(DataType.FLOAT, 1024); - } + try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) { + val x = Nd4j.rand(DataType.FLOAT, 1024); + } - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); + Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - Files.delete(tmpFile); - }); + Files.delete(tmpFile); + }); } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMigrateToWorkspace(){ val src = Nd4j.createFromArray (1L,2L); val wsConf = new WorkspaceConfiguration().builder().build(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java index 7d6141bfd..595e60b2b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java @@ -25,9 +25,10 @@ import lombok.val; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; @@ -51,72 +52,67 @@ import java.util.concurrent.CopyOnWriteArrayList; import static org.junit.jupiter.api.Assertions.*; @Slf4j -@RunWith(Parameterized.class) -public class WorkspaceProviderTests extends BaseNd4jTest { + +public class WorkspaceProviderTests extends BaseNd4jTestWithBackends { private static final WorkspaceConfiguration basicConfiguration = WorkspaceConfiguration.builder().initialSize(81920) - .overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL).policyLearning(LearningPolicy.NONE) - .policyMirroring(MirroringPolicy.FULL).policyAllocation(AllocationPolicy.OVERALLOCATE).build(); + .overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL).policyLearning(LearningPolicy.NONE) + .policyMirroring(MirroringPolicy.FULL).policyAllocation(AllocationPolicy.OVERALLOCATE).build(); private static final WorkspaceConfiguration bigConfiguration = WorkspaceConfiguration.builder() - .initialSize(20 * 1024 * 1024L).overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL) - .policyLearning(LearningPolicy.NONE).policyMirroring(MirroringPolicy.FULL) - .policyAllocation(AllocationPolicy.OVERALLOCATE).build(); + .initialSize(20 * 1024 * 1024L).overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL) + .policyLearning(LearningPolicy.NONE).policyMirroring(MirroringPolicy.FULL) + .policyAllocation(AllocationPolicy.OVERALLOCATE).build(); private static final WorkspaceConfiguration loopConfiguration = WorkspaceConfiguration.builder().initialSize(0) - .overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL).policyLearning(LearningPolicy.OVER_TIME) - .policyMirroring(MirroringPolicy.FULL).policyAllocation(AllocationPolicy.STRICT).build(); + .overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL).policyLearning(LearningPolicy.OVER_TIME) + .policyMirroring(MirroringPolicy.FULL).policyAllocation(AllocationPolicy.STRICT).build(); private static final WorkspaceConfiguration delayedConfiguration = WorkspaceConfiguration.builder().initialSize(0) - .overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL).policyLearning(LearningPolicy.OVER_TIME) - .policyMirroring(MirroringPolicy.FULL).cyclesBeforeInitialization(3) - .policyAllocation(AllocationPolicy.STRICT).build(); + .overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL).policyLearning(LearningPolicy.OVER_TIME) + .policyMirroring(MirroringPolicy.FULL).cyclesBeforeInitialization(3) + .policyAllocation(AllocationPolicy.STRICT).build(); private static final WorkspaceConfiguration reallocateConfiguration = WorkspaceConfiguration.builder() - .initialSize(0).overallocationLimit(0.1).policySpill(SpillPolicy.REALLOCATE) - .policyLearning(LearningPolicy.OVER_TIME).policyMirroring(MirroringPolicy.FULL) - .policyAllocation(AllocationPolicy.STRICT).build(); + .initialSize(0).overallocationLimit(0.1).policySpill(SpillPolicy.REALLOCATE) + .policyLearning(LearningPolicy.OVER_TIME).policyMirroring(MirroringPolicy.FULL) + .policyAllocation(AllocationPolicy.STRICT).build(); private static final WorkspaceConfiguration reallocateDelayedConfiguration = WorkspaceConfiguration.builder() - .initialSize(0).overallocationLimit(0.1).policySpill(SpillPolicy.REALLOCATE) - .cyclesBeforeInitialization(3).policyLearning(LearningPolicy.OVER_TIME) - .policyMirroring(MirroringPolicy.FULL).policyAllocation(AllocationPolicy.STRICT).build(); + .initialSize(0).overallocationLimit(0.1).policySpill(SpillPolicy.REALLOCATE) + .cyclesBeforeInitialization(3).policyLearning(LearningPolicy.OVER_TIME) + .policyMirroring(MirroringPolicy.FULL).policyAllocation(AllocationPolicy.STRICT).build(); private static final WorkspaceConfiguration reallocateUnspecifiedConfiguration = WorkspaceConfiguration.builder() - .initialSize(0).overallocationLimit(0.0).policySpill(SpillPolicy.REALLOCATE) - .policyLearning(LearningPolicy.OVER_TIME).policyMirroring(MirroringPolicy.FULL) - .policyAllocation(AllocationPolicy.OVERALLOCATE).policyReset(ResetPolicy.BLOCK_LEFT).build(); + .initialSize(0).overallocationLimit(0.0).policySpill(SpillPolicy.REALLOCATE) + .policyLearning(LearningPolicy.OVER_TIME).policyMirroring(MirroringPolicy.FULL) + .policyAllocation(AllocationPolicy.OVERALLOCATE).policyReset(ResetPolicy.BLOCK_LEFT).build(); private static final WorkspaceConfiguration firstConfiguration = WorkspaceConfiguration.builder().initialSize(0) - .overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL) - .policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL) - .policyAllocation(AllocationPolicy.STRICT).build(); + .overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL) + .policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL) + .policyAllocation(AllocationPolicy.STRICT).build(); private static final WorkspaceConfiguration circularConfiguration = WorkspaceConfiguration.builder() - .minSize(10 * 1024L * 1024L).overallocationLimit(1.0).policySpill(SpillPolicy.EXTERNAL) - .policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL) - .policyAllocation(AllocationPolicy.STRICT).policyReset(ResetPolicy.ENDOFBUFFER_REACHED).build(); + .minSize(10 * 1024L * 1024L).overallocationLimit(1.0).policySpill(SpillPolicy.EXTERNAL) + .policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL) + .policyAllocation(AllocationPolicy.STRICT).policyReset(ResetPolicy.ENDOFBUFFER_REACHED).build(); private static final WorkspaceConfiguration adsiConfiguration = - WorkspaceConfiguration.builder().overallocationLimit(3.0).policySpill(SpillPolicy.REALLOCATE) - .policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL) - .policyAllocation(AllocationPolicy.OVERALLOCATE) - .policyReset(ResetPolicy.ENDOFBUFFER_REACHED).build(); + WorkspaceConfiguration.builder().overallocationLimit(3.0).policySpill(SpillPolicy.REALLOCATE) + .policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL) + .policyAllocation(AllocationPolicy.OVERALLOCATE) + .policyReset(ResetPolicy.ENDOFBUFFER_REACHED).build(); - DataType initialType; - - public WorkspaceProviderTests(Nd4jBackend backend) { - super(backend); - this.initialType = Nd4j.dataType(); - } + DataType initialType = Nd4j.dataType(); @AfterEach - public void shutUp() { + public void shutUp(Nd4jBackend backend) { Nd4j.getMemoryManager().setCurrentWorkspace(null); Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); Nd4j.setDataType(this.initialType); @@ -128,21 +124,23 @@ public class WorkspaceProviderTests extends BaseNd4jTest { * @throws Exception */ @Test - public void testUnboundedLoop2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUnboundedLoop2(Nd4jBackend backend) { WorkspaceConfiguration configuration = - WorkspaceConfiguration.builder().initialSize(0).policyReset(ResetPolicy.ENDOFBUFFER_REACHED) - .policyAllocation(AllocationPolicy.OVERALLOCATE).overallocationLimit(4.0) - .policyLearning(LearningPolicy.OVER_TIME).cyclesBeforeInitialization(5).build(); + WorkspaceConfiguration.builder().initialSize(0).policyReset(ResetPolicy.ENDOFBUFFER_REACHED) + .policyAllocation(AllocationPolicy.OVERALLOCATE).overallocationLimit(4.0) + .policyLearning(LearningPolicy.OVER_TIME).cyclesBeforeInitialization(5).build(); Nd4jWorkspace ws1 = - (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "ITER"); + (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "ITER"); long requiredMemory = 100 * Nd4j.sizeOfDataType(); long shiftedSize = ((long) (requiredMemory * 1.3)) + (8 - (((long) (requiredMemory * 1.3)) % 8)); for (int x = 0; x < 100; x++) { try (Nd4jWorkspace wsI = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread(configuration, "ITER").notifyScopeEntered()) { + .getWorkspaceForCurrentThread(configuration, "ITER").notifyScopeEntered()) { INDArray array = Nd4j.create(100); } @@ -163,26 +161,28 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testUnboundedLoop1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testUnboundedLoop1(Nd4jBackend backend) { WorkspaceConfiguration configuration = WorkspaceConfiguration.builder() - .initialSize(100 * 100 * Nd4j.sizeOfDataType()).policyReset(ResetPolicy.ENDOFBUFFER_REACHED) - .policyAllocation(AllocationPolicy.STRICT).build(); + .initialSize(100 * 100 * Nd4j.sizeOfDataType()).policyReset(ResetPolicy.ENDOFBUFFER_REACHED) + .policyAllocation(AllocationPolicy.STRICT).build(); for (int x = 0; x < 100; x++) { try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread(configuration, "ITER").notifyScopeEntered()) { + .getWorkspaceForCurrentThread(configuration, "ITER").notifyScopeEntered()) { INDArray array = Nd4j.create(100); } Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, - "ITER"); + "ITER"); assertEquals((x + 1) * 100 * Nd4j.sizeOfDataType(), ws1.getPrimaryOffset()); } Nd4jWorkspace ws1 = - (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "ITER"); + (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "ITER"); assertEquals(100 * 100 * Nd4j.sizeOfDataType(), ws1.getPrimaryOffset()); // just to trigger reset @@ -197,18 +197,17 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMultithreading1() throws Exception { final List workspaces = new CopyOnWriteArrayList<>(); Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration); Thread[] threads = new Thread[20]; for (int x = 0; x < threads.length; x++) { - threads[x] = new Thread(new Runnable() { - @Override - public void run() { - MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(); - workspaces.add(workspace); - } + threads[x] = new Thread(() -> { + MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(); + workspaces.add(workspace); }); threads[x].start(); @@ -232,21 +231,23 @@ public class WorkspaceProviderTests extends BaseNd4jTest { @Test - public void testNestedWorkspacesOverlap2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNestedWorkspacesOverlap2(Nd4jBackend backend) { Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration); assertFalse(Nd4j.getWorkspaceManager().checkIfWorkspaceExists("WS1")); assertFalse(Nd4j.getWorkspaceManager().checkIfWorkspaceExists("WS2")); try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1") - .notifyScopeEntered()) { + .notifyScopeEntered()) { INDArray array = Nd4j.create(new double[] {6f, 3f, 1f, 9f, 21f}); INDArray array3 = null; long reqMem = 5 * Nd4j.sizeOfDataType(DataType.DOUBLE); assertEquals(reqMem + reqMem % 16, ws1.getPrimaryOffset()); try (Nd4jWorkspace ws2 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS2") - .notifyScopeEntered()) { + .notifyScopeEntered()) { INDArray array2 = Nd4j.create(new double[] {1f, 2f, 3f, 4f, 5f}); @@ -255,7 +256,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { assertEquals(reqMem + reqMem % 16, ws2.getPrimaryOffset()); try (Nd4jWorkspace ws3 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1") - .notifyScopeBorrowed()) { + .notifyScopeBorrowed()) { assertTrue(ws1 == ws3); assertTrue(ws1 == Nd4j.getMemoryManager().getCurrentWorkspace()); @@ -281,7 +282,9 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testNestedWorkspacesOverlap1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNestedWorkspacesOverlap1(Nd4jBackend backend) { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration); try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1").notifyScopeEntered()) { @@ -298,7 +301,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { assertEquals(reqMem + (Nd4jWorkspace.alignmentBase - reqMem % Nd4jWorkspace.alignmentBase), ws2.getPrimaryOffset()); try (Nd4jWorkspace ws3 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1") - .notifyScopeBorrowed()) { + .notifyScopeBorrowed()) { assertTrue(ws1 == ws3); INDArray array3 = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); @@ -313,6 +316,8 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWorkspacesSerde3() throws Exception { INDArray array = Nd4j.create(10).assign(1.0); INDArray restored = null; @@ -322,7 +327,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { Nd4j.write(array, dos); try (Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getAndActivateWorkspace(basicConfiguration, "WS_1")) { + .getAndActivateWorkspace(basicConfiguration, "WS_1")) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { workspace.enableDebug(true); @@ -345,6 +350,8 @@ public class WorkspaceProviderTests extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWorkspacesSerde2() throws Exception { INDArray array = Nd4j.create(10).assign(1.0); INDArray restored = null; @@ -354,7 +361,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { Nd4j.write(array, dos); try (Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getAndActivateWorkspace(basicConfiguration, "WS_1")) { + .getAndActivateWorkspace(basicConfiguration, "WS_1")) { workspace.enableDebug(true); ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); @@ -373,6 +380,8 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWorkspacesSerde1() throws Exception { int[] shape = new int[] {17, 57, 79}; INDArray array = Nd4j.create(shape).assign(1.0); @@ -397,9 +406,11 @@ public class WorkspaceProviderTests extends BaseNd4jTest { @Test - public void testCircularBufferReset1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCircularBufferReset1(Nd4jBackend backend) { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread(circularConfiguration, "WSR_1"); + .getWorkspaceForCurrentThread(circularConfiguration, "WSR_1"); try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WSR_1")) { Nd4j.create(10000); @@ -429,9 +440,11 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testVariableInput1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVariableInput1(Nd4jBackend backend) { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread(adsiConfiguration, "ADSI"); + .getWorkspaceForCurrentThread(adsiConfiguration, "ADSI"); INDArray array1 = null; INDArray array2 = null; @@ -517,13 +530,15 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testReallocate3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReallocate3(Nd4jBackend backend) { MemoryWorkspace workspace = Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread(reallocateUnspecifiedConfiguration, "WS_1"); + .getWorkspaceForCurrentThread(reallocateUnspecifiedConfiguration, "WS_1"); for (int i = 1; i <= 10; i++) { try (MemoryWorkspace ws = Nd4j.getWorkspaceManager() - .getAndActivateWorkspace(reallocateUnspecifiedConfiguration, "WS_1")) { + .getAndActivateWorkspace(reallocateUnspecifiedConfiguration, "WS_1")) { INDArray array = Nd4j.create(100 * i); } @@ -537,7 +552,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { for (int i = 10; i > 0; i--) { try (MemoryWorkspace ws = Nd4j.getWorkspaceManager() - .getAndActivateWorkspace(reallocateUnspecifiedConfiguration, "WS_1")) { + .getAndActivateWorkspace(reallocateUnspecifiedConfiguration, "WS_1")) { INDArray array = Nd4j.create(100 * i); } } @@ -547,13 +562,15 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testReallocate2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReallocate2(Nd4jBackend backend) { MemoryWorkspace workspace = - Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(reallocateDelayedConfiguration, "WS_1"); + Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(reallocateDelayedConfiguration, "WS_1"); for (int i = 1; i <= 10; i++) { try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateDelayedConfiguration, - "WS_1")) { + "WS_1")) { INDArray array = Nd4j.create(100 * i); } @@ -565,17 +582,19 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testCircularLearning1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCircularLearning1(Nd4jBackend backend) { INDArray array1; INDArray array2; for (int i = 0; i < 2; i++) { try (MemoryWorkspace workspace = - Nd4j.getWorkspaceManager().getAndActivateWorkspace(circularConfiguration, "WSX")) { + Nd4j.getWorkspaceManager().getAndActivateWorkspace(circularConfiguration, "WSX")) { array1 = Nd4j.create(10).assign(1); } Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread(circularConfiguration, "WSX"); + .getWorkspaceForCurrentThread(circularConfiguration, "WSX"); assertEquals(10 * 1024 * 1024L, workspace.getCurrentSize()); log.info("Current step number: {}", workspace.getStepNumber()); if (i == 0) @@ -587,7 +606,9 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testReallocate1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReallocate1(Nd4jBackend backend) { try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateConfiguration, "WS_1")) { INDArray array = Nd4j.create(100); } @@ -595,7 +616,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread(reallocateConfiguration, "WS_1"); + .getWorkspaceForCurrentThread(reallocateConfiguration, "WS_1"); workspace.initializeWorkspace(); assertEquals(100 * Nd4j.sizeOfDataType(), workspace.getCurrentSize()); @@ -620,7 +641,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { @Test @Disabled("raver119: This test doesn't make any sense to me these days. We're borrowing from the same workspace. Why?") - public void testNestedWorkspaces11() { + public void testNestedWorkspaces11(Nd4jBackend backend) { for (int x = 1; x < 10; x++) { try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) { INDArray array1 = Nd4j.create(100 * x); @@ -641,15 +662,17 @@ public class WorkspaceProviderTests extends BaseNd4jTest { @Test - public void testNestedWorkspaces10() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNestedWorkspaces10(Nd4jBackend backend) { for (int x = 1; x < 10; x++) { try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) { INDArray array1 = Nd4j.create(100 * x); try (MemoryWorkspace ws2 = - Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) { + Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) { INDArray array2 = Nd4j.create(100 * x); try (MemoryWorkspace ws3 = Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread(basicConfiguration, "WS_1").notifyScopeBorrowed()) { + .getWorkspaceForCurrentThread(basicConfiguration, "WS_1").notifyScopeBorrowed()) { INDArray array3 = Nd4j.create(100 * x); } @@ -660,16 +683,18 @@ public class WorkspaceProviderTests extends BaseNd4jTest { @Test - public void testNestedWorkspaces9() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNestedWorkspaces9(Nd4jBackend backend) { for (int x = 1; x < 10; x++) { try (MemoryWorkspace ws = - Nd4j.getWorkspaceManager().getAndActivateWorkspace(delayedConfiguration, "WS_1")) { + Nd4j.getWorkspaceManager().getAndActivateWorkspace(delayedConfiguration, "WS_1")) { INDArray array = Nd4j.create(100 * x); } } Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread(delayedConfiguration, "WS_1"); + .getWorkspaceForCurrentThread(delayedConfiguration, "WS_1"); workspace.initializeWorkspace(); assertEquals(300 * Nd4j.sizeOfDataType(), workspace.getCurrentSize()); @@ -677,7 +702,9 @@ public class WorkspaceProviderTests extends BaseNd4jTest { @Test - public void testNestedWorkspaces8() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNestedWorkspaces8(Nd4jBackend backend) { try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(loopConfiguration, "WS_1")) { INDArray array = Nd4j.create(100); } @@ -685,7 +712,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread(loopConfiguration, "WS_1"); + .getWorkspaceForCurrentThread(loopConfiguration, "WS_1"); workspace.initializeWorkspace(); assertEquals(100 * Nd4j.sizeOfDataType(), workspace.getCurrentSize()); @@ -700,9 +727,11 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testNestedWorkspaces7() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNestedWorkspaces7(Nd4jBackend backend) { try (Nd4jWorkspace wsExternal = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getAndActivateWorkspace(basicConfiguration, "External")) { + .getAndActivateWorkspace(basicConfiguration, "External")) { INDArray array1 = Nd4j.create(10); INDArray array2 = null; INDArray array3 = null; @@ -711,12 +740,12 @@ public class WorkspaceProviderTests extends BaseNd4jTest { try (Nd4jWorkspace wsFeedForward = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getAndActivateWorkspace(basicConfiguration, "FeedForward")) { + .getAndActivateWorkspace(basicConfiguration, "FeedForward")) { array2 = Nd4j.create(10); assertEquals(true, array2.isAttached()); try (Nd4jWorkspace borrowed = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread("External").notifyScopeBorrowed()) { + .getWorkspaceForCurrentThread("External").notifyScopeBorrowed()) { array3 = Nd4j.create(10); assertTrue(wsExternal == array3.data().getParentWorkspace()); @@ -740,10 +769,12 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testNestedWorkspaces6() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNestedWorkspaces6(Nd4jBackend backend) { try (Nd4jWorkspace wsExternal = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getAndActivateWorkspace(firstConfiguration, "External")) { + .getAndActivateWorkspace(firstConfiguration, "External")) { INDArray array1 = Nd4j.create(10); INDArray array2 = null; INDArray array3 = null; @@ -751,12 +782,12 @@ public class WorkspaceProviderTests extends BaseNd4jTest { try (Nd4jWorkspace wsFeedForward = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getAndActivateWorkspace(firstConfiguration, "FeedForward")) { + .getAndActivateWorkspace(firstConfiguration, "FeedForward")) { array2 = Nd4j.create(10); assertEquals(true, array2.isAttached()); try (Nd4jWorkspace borrowed = (Nd4jWorkspace) Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread("External").notifyScopeBorrowed()) { + .getWorkspaceForCurrentThread("External").notifyScopeBorrowed()) { array3 = Nd4j.create(10); assertTrue(wsExternal == array3.data().getParentWorkspace()); @@ -778,14 +809,16 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testNestedWorkspaces5() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNestedWorkspaces5(Nd4jBackend backend) { Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration); try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1") - .notifyScopeEntered()) { + .notifyScopeEntered()) { INDArray array1 = Nd4j.create(100); try (Nd4jWorkspace ws2 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1") - .notifyScopeEntered()) { + .notifyScopeEntered()) { INDArray array2 = Nd4j.create(100); } @@ -803,20 +836,22 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testNestedWorkspaces4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNestedWorkspaces4(Nd4jBackend backend) { Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration); try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1") - .notifyScopeEntered()) { + .notifyScopeEntered()) { INDArray array1 = Nd4j.create(100); try (Nd4jWorkspace ws2 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS2") - .notifyScopeEntered()) { + .notifyScopeEntered()) { INDArray array2 = Nd4j.create(100); try (Nd4jWorkspace ws3 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS3") - .notifyScopeEntered()) { + .notifyScopeEntered()) { INDArray array3 = Nd4j.create(100); assertEquals(100 * Nd4j.sizeOfDataType(), ws1.getPrimaryOffset()); @@ -847,13 +882,15 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testNestedWorkspaces3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNestedWorkspaces3(Nd4jBackend backend) { Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration); // We open top-level workspace try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1") - .notifyScopeEntered()) { + .notifyScopeEntered()) { INDArray array1 = Nd4j.create(100); @@ -861,7 +898,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { // we open first nested workspace try (Nd4jWorkspace ws2 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS2") - .notifyScopeEntered()) { + .notifyScopeEntered()) { assertEquals(0 * Nd4j.sizeOfDataType(), ws2.getPrimaryOffset()); INDArray array2 = Nd4j.create(100); @@ -872,7 +909,7 @@ public class WorkspaceProviderTests extends BaseNd4jTest { // and second nexted workspace try (Nd4jWorkspace ws3 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS3") - .notifyScopeEntered()) { + .notifyScopeEntered()) { assertEquals(0 * Nd4j.sizeOfDataType(), ws3.getPrimaryOffset()); INDArray array2 = Nd4j.create(100); @@ -893,11 +930,13 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testNestedWorkspaces2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNestedWorkspaces2(Nd4jBackend backend) { Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration); try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1") - .notifyScopeEntered()) { + .notifyScopeEntered()) { INDArray array1 = Nd4j.create(100); @@ -922,19 +961,21 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testNestedWorkspaces1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNestedWorkspaces1(Nd4jBackend backend) { Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration); try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1") - .notifyScopeEntered()) { + .notifyScopeEntered()) { INDArray array1 = Nd4j.create(100); assertEquals(100 * Nd4j.sizeOfDataType(), ws1.getPrimaryOffset()); try (Nd4jWorkspace ws2 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS2") - .notifyScopeEntered()) { + .notifyScopeEntered()) { assertEquals(0 * Nd4j.sizeOfDataType(), ws2.getPrimaryOffset()); INDArray array2 = Nd4j.create(100); @@ -950,7 +991,9 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test - public void testNewWorkspace1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNewWorkspace1(Nd4jBackend backend) { MemoryWorkspace workspace1 = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(); assertNotEquals(null, workspace1); @@ -961,20 +1004,19 @@ public class WorkspaceProviderTests extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWorkspaceGc_1() throws Exception { for (int e = 0; e < 10; e++) { val f = e; - val t = new Thread(new Runnable() { - @Override - public void run() { - val wsConf = WorkspaceConfiguration.builder() - .initialSize(1000000).build(); - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsConf, "SomeRandomName999" + f)) { - val array = Nd4j.create(2, 2); - } - //Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); + val t = new Thread(() -> { + val wsConf = WorkspaceConfiguration.builder() + .initialSize(1000000).build(); + try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(wsConf, "SomeRandomName999" + f)) { + val array = Nd4j.create(2, 2); } + //Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); }); t.start(); t.join(); @@ -992,15 +1034,17 @@ public class WorkspaceProviderTests extends BaseNd4jTest { @Disabled @Test - public void testMemcpy1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMemcpy1(Nd4jBackend backend) { INDArray warmUp = Nd4j.create(100000); for (int x = 0; x < 5000; x++) { warmUp.addi(0.1); } WorkspaceConfiguration configuration = - WorkspaceConfiguration.builder().policyMirroring(MirroringPolicy.HOST_ONLY) - .initialSize(1024L * 1024L * 1024L).policyLearning(LearningPolicy.NONE).build(); + WorkspaceConfiguration.builder().policyMirroring(MirroringPolicy.HOST_ONLY) + .initialSize(1024L * 1024L * 1024L).policyLearning(LearningPolicy.NONE).build(); INDArray array = Nd4j.createUninitialized(150000000); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/list/NDArrayListTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/list/NDArrayListTest.java index db3e84870..f892ec843 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/list/NDArrayListTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/list/NDArrayListTest.java @@ -21,7 +21,9 @@ package org.nd4j.list; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.factory.Nd4jBackend; import java.util.ArrayList; @@ -29,11 +31,8 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; -public class NDArrayListTest extends BaseNd4jTest { +public class NDArrayListTest extends BaseNd4jTestWithBackends { - public NDArrayListTest(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -41,7 +40,9 @@ public class NDArrayListTest extends BaseNd4jTest { } @Test - public void testList() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testList(Nd4jBackend backend) { NDArrayList ndArrayList = new NDArrayList(); List arrayAssertion = new ArrayList<>(); for(int i = 0; i < 11; i++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/base64/Nd4jBase64Test.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/base64/Nd4jBase64Test.java index aa4fff5dc..2fe1a3a24 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/base64/Nd4jBase64Test.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/base64/Nd4jBase64Test.java @@ -21,18 +21,17 @@ package org.nd4j.serde.base64; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; -public class Nd4jBase64Test extends BaseNd4jTest { +public class Nd4jBase64Test extends BaseNd4jTestWithBackends { - public Nd4jBase64Test(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -40,7 +39,9 @@ public class Nd4jBase64Test extends BaseNd4jTest { } @Test - public void testBase64() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBase64(Nd4jBackend backend) throws Exception { INDArray arr = Nd4j.linspace(1, 4, 4); String base64 = Nd4jBase64.base64String(arr); INDArray from = Nd4jBase64.fromBase64(base64); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/binary/BinarySerdeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/binary/BinarySerdeTest.java index 78356eb3d..bb8fd4ffa 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/binary/BinarySerdeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/binary/BinarySerdeTest.java @@ -22,8 +22,10 @@ package org.nd4j.serde.binary; import org.apache.commons.lang3.time.StopWatch; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.OpValidationSuite; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -38,11 +40,8 @@ import java.util.UUID; import static org.junit.jupiter.api.Assertions.*; -public class BinarySerdeTest extends BaseNd4jTest { +public class BinarySerdeTest extends BaseNd4jTestWithBackends { - public BinarySerdeTest(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -50,7 +49,9 @@ public class BinarySerdeTest extends BaseNd4jTest { } @Test - public void testToAndFrom() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToAndFrom(Nd4jBackend backend) { INDArray arr = Nd4j.scalar(1.0); ByteBuffer buffer = BinarySerde.toByteBuffer(arr); INDArray back = BinarySerde.toArray(buffer); @@ -58,7 +59,9 @@ public class BinarySerdeTest extends BaseNd4jTest { } @Test - public void testToAndFromHeapBuffer() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToAndFromHeapBuffer(Nd4jBackend backend) { INDArray arr = Nd4j.scalar(1.0); ByteBuffer buffer = BinarySerde.toByteBuffer(arr); ByteBuffer heapBuffer = ByteBuffer.allocate(buffer.remaining()); @@ -68,7 +71,9 @@ public class BinarySerdeTest extends BaseNd4jTest { } @Test - public void testToAndFromCompressed() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToAndFromCompressed(Nd4jBackend backend) { OpValidationSuite.ignoreFailing(); //Failing 2019/01/24 INDArray arr = Nd4j.scalar(1.0); INDArray compress = Nd4j.getCompressor().compress(arr, "GZIP"); @@ -82,7 +87,9 @@ public class BinarySerdeTest extends BaseNd4jTest { @Test - public void testToAndFromCompressedLarge() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testToAndFromCompressedLarge(Nd4jBackend backend) { OpValidationSuite.ignoreFailing(); //Failing 2019/01/24 INDArray arr = Nd4j.zeros((int) 1e7); INDArray compress = Nd4j.getCompressor().compress(arr, "GZIP"); @@ -96,7 +103,9 @@ public class BinarySerdeTest extends BaseNd4jTest { @Test - public void testReadWriteFile() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReadWriteFile(Nd4jBackend backend) throws Exception { File tmpFile = new File(System.getProperty("java.io.tmpdir"), "ndarraytmp-" + UUID.randomUUID().toString() + " .bin"); tmpFile.deleteOnExit(); @@ -107,7 +116,9 @@ public class BinarySerdeTest extends BaseNd4jTest { } @Test - public void testReadShapeFile() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReadShapeFile(Nd4jBackend backend) throws Exception { File tmpFile = new File(System.getProperty("java.io.tmpdir"), "ndarraytmp-" + UUID.randomUUID().toString() + " .bin"); tmpFile.deleteOnExit(); @@ -119,7 +130,9 @@ public class BinarySerdeTest extends BaseNd4jTest { } @Test - public void timeOldVsNew() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void timeOldVsNew(Nd4jBackend backend) throws Exception { int numTrials = 1000; long oldTotal = 0; long newTotal = 0; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/smoketests/SmokeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/smoketests/SmokeTest.java index 4f658e4fa..89029ec9d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/smoketests/SmokeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/smoketests/SmokeTest.java @@ -25,6 +25,8 @@ package org.nd4j.smoketests; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -35,6 +37,8 @@ public class SmokeTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBasic() { Nd4j.getEnvironment().setDebug(true); Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder() diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java index 095818e1c..8538a4391 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/systeminfo/TestSystemInfo.java @@ -21,10 +21,14 @@ package org.nd4j.systeminfo; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.common.tests.BaseND4JTest; public class TestSystemInfo extends BaseND4JTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSystemInfo(){ SystemInfo.printSystemInfo(); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/kotlin/org/nd4j/linalg/custom/CustomOpTensorflowInteropTests.kt b/nd4j/nd4j-backends/nd4j-tests/src/test/kotlin/org/nd4j/linalg/custom/CustomOpTensorflowInteropTests.kt deleted file mode 100644 index 6f728f79d..000000000 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/kotlin/org/nd4j/linalg/custom/CustomOpTensorflowInteropTests.kt +++ /dev/null @@ -1,118 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.nd4j.linalg.custom - -import junit.framework.Assert.assertEquals -import org.junit.jupiter.api.Disabled -import org.junit.Test -import org.nd4j.linalg.api.buffer.DataType -import org.nd4j.linalg.api.ops.impl.image.CropAndResize -import org.nd4j.linalg.factory.Nd4j -import org.nd4j.samediff.frameworkimport.tensorflow.* -import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFrameworkImporter -import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph -import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner - -class CustomOpTensorflowInteropTests { - - @Test - @Disabled("Tensorflow expects different shape") - fun testCropAndResize() { - val image = Nd4j.createUninitialized(DataType.FLOAT, 1, 2, 2, 1) - val boxes = Nd4j.createFromArray(*floatArrayOf(1f, 2f, 3f, 4f)).reshape(1, 4) - val box_indices = Nd4j.createFromArray(*intArrayOf(0)) - val crop_size = Nd4j.createFromArray(*intArrayOf(1, 2)).reshape( 2) - val imageNode = NodeDef { - op = "Placeholder" - name = "image" - Attribute("dtype", AttrValue { - type = org.tensorflow.framework.DataType.DT_FLOAT - }) - } - - val boxesNode = NodeDef { - op = "Placeholder" - name = "boxes" - Attribute("dtype", AttrValue { - type = org.tensorflow.framework.DataType.DT_FLOAT - }) - } - - val boxIndicesNode = NodeDef { - op = "Placeholder" - name = "boxIndices" - Attribute("dtype", AttrValue { - type = org.tensorflow.framework.DataType.DT_INT32 - }) - } - - val cropSizesNode = NodeDef { - op = "Placeholder" - name = "cropSize" - Attribute("dtype", AttrValue { - type = org.tensorflow.framework.DataType.DT_INT32 - }) - } - - - val opNode = NodeDef { - op = "CropAndResize" - name = "output" - Input("image") - Input("boxes") - Input("boxIndices") - Input("cropSize") - Attribute("extrapolation_value", AttrValue { - f = 0.5f - }) - Attribute("T", AttrValue { - type = org.tensorflow.framework.DataType.DT_FLOAT - }) - } - - val graph = GraphDef { - Node(imageNode) - Node(boxesNode) - Node(boxIndicesNode) - Node(cropSizesNode) - Node(opNode) - - } - - val importer = TensorflowFrameworkImporter() - val irGraph = TensorflowIRGraph(graph,importer.opDefList,importer.registry) - val runner = TensorflowIRGraphRunner(irGraph,listOf("image","boxes","boxIndices","cropSize"),listOf("output")) - val tfResult = runner.run(mapOf("image" to image,"boxes" to boxes,"boxIndices" to box_indices,"cropSize" to crop_size)) - val outputArr = tfResult["output"] - //Output shape mismatch - TF [2, 2, 1, 1] vs SD: [1, 2, 1, 1] - val output = Nd4j.create(DataType.FLOAT, 2, 2, 1, 1) - Nd4j.exec( - CropAndResize( - image, boxes, box_indices, crop_size, CropAndResize.Method.BILINEAR, 0.5, - output - ) - ) - - assertEquals(outputArr,output) - } - - -} \ No newline at end of file diff --git a/nd4j/nd4j-common-tests/pom.xml b/nd4j/nd4j-common-tests/pom.xml index 61cdbb1a3..064bebcf3 100644 --- a/nd4j/nd4j-common-tests/pom.xml +++ b/nd4j/nd4j-common-tests/pom.xml @@ -50,6 +50,11 @@ compile + + org.junit.jupiter + junit-jupiter-params + compile + org.junit.jupiter junit-jupiter diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/linalg/BaseNd4jTestWithBackends.java similarity index 78% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java rename to nd4j/nd4j-common-tests/src/main/java/org/nd4j/linalg/BaseNd4jTestWithBackends.java index c1061f1a6..c5a30ed12 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/linalg/BaseNd4jTestWithBackends.java @@ -22,7 +22,7 @@ package org.nd4j.linalg; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.BeforeEach; -import org.junit.runner.RunWith; +import org.junit.jupiter.params.provider.Arguments; import org.junit.runners.Parameterized; import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.common.io.ReflectionUtils; @@ -31,14 +31,15 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import java.util.*; +import java.util.stream.Stream; /** * Base Nd4j test * @author Adam Gibson */ -@RunWith(Parameterized.class) + @Slf4j -public abstract class BaseNd4jTest extends BaseND4JTest { +public abstract class BaseNd4jTestWithBackends extends BaseND4JTest { private static List BACKENDS = new ArrayList<>(); static { List backendsToRun = Nd4jTestSuite.backendsToRun(); @@ -56,29 +57,10 @@ public abstract class BaseNd4jTest extends BaseND4JTest { protected String name; public final static String DEFAULT_BACKEND = "org.nd4j.linalg.defaultbackend"; - public BaseNd4jTest() { - this("", getDefaultBackend()); - } - public BaseNd4jTest(String name) { - this(name, getDefaultBackend()); - } - public BaseNd4jTest(String name, Nd4jBackend backend) { - this.backend = backend; - this.name = name; - } - - public BaseNd4jTest(Nd4jBackend backend) { - this(backend.getClass().getName() + UUID.randomUUID().toString(), backend); - } - - @Parameterized.Parameters(name = "{index}: backend({0})={1}") - public static Collection configs() { - List ret = new ArrayList<>(); - for (Nd4jBackend backend : BACKENDS) - ret.add(new Object[] {backend}); - return ret; + public static Stream configs() { + return BACKENDS.stream().map(input -> Arguments.of(input)); } @BeforeEach @@ -87,7 +69,7 @@ public abstract class BaseNd4jTest extends BaseND4JTest { } /** - * Get the default backend (jblas) + * Get the default backend (nd4j) * The default backend can be overridden by also passing: * -Dorg.nd4j.linalg.defaultbackend=your.backend.classname * @return the default backend based on the diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestSuite.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/linalg/Nd4jTestSuite.java similarity index 87% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestSuite.java rename to nd4j/nd4j-common-tests/src/main/java/org/nd4j/linalg/Nd4jTestSuite.java index e9c6c3463..255dd757f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestSuite.java +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/linalg/Nd4jTestSuite.java @@ -20,7 +20,6 @@ package org.nd4j.linalg; -import org.junit.runners.BlockJUnit4ClassRunner; import org.nd4j.common.config.ND4JClassLoading; import org.nd4j.linalg.factory.Nd4jBackend; @@ -28,7 +27,7 @@ import java.util.ArrayList; import java.util.List; import java.util.ServiceLoader; -public class Nd4jTestSuite extends BlockJUnit4ClassRunner { +public class Nd4jTestSuite { //the system property for what backends should run public final static String BACKENDS_TO_LOAD = "backends"; private static List BACKENDS = new ArrayList<>(); @@ -39,14 +38,7 @@ public class Nd4jTestSuite extends BlockJUnit4ClassRunner { } } - /** - * Only called reflectively. Do not use programmatically. - * - * @param klass - */ - public Nd4jTestSuite(Class klass) throws Throwable { - super(klass); - } + /** * Based on the jvm arguments, an empty list is returned diff --git a/nd4j/samediff-import/pom.xml b/nd4j/samediff-import/pom.xml index 931016732..cd4585698 100644 --- a/nd4j/samediff-import/pom.xml +++ b/nd4j/samediff-import/pom.xml @@ -182,4 +182,10 @@ + + + testresources + + + diff --git a/nd4j/samediff-import/samediff-import-api/pom.xml b/nd4j/samediff-import/samediff-import-api/pom.xml index 80ff25f11..1ff787d38 100644 --- a/nd4j/samediff-import/samediff-import-api/pom.xml +++ b/nd4j/samediff-import/samediff-import-api/pom.xml @@ -151,5 +151,9 @@ - + + + testresources + + diff --git a/nd4j/samediff-import/samediff-import-onnx/pom.xml b/nd4j/samediff-import/samediff-import-onnx/pom.xml index 68c80e38d..212b76cb0 100644 --- a/nd4j/samediff-import/samediff-import-onnx/pom.xml +++ b/nd4j/samediff-import/samediff-import-onnx/pom.xml @@ -73,5 +73,9 @@ - + + + testresources + + diff --git a/nd4j/samediff-import/samediff-import-tensorflow/pom.xml b/nd4j/samediff-import/samediff-import-tensorflow/pom.xml index 334a75bac..dc4a5f5b6 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/pom.xml +++ b/nd4j/samediff-import/samediff-import-tensorflow/pom.xml @@ -52,12 +52,40 @@ + + org.springframework + spring-core + 5.0.2.RELEASE + test + + + org.junit.jupiter + junit-jupiter-api + + + org.junit.jupiter + junit-jupiter-engine + + + org.junit.jupiter + junit-jupiter-params + org.nd4j samediff-import-api ${project.version} + + org.nd4j + nd4j-common-tests + ${project.version} + test + - + + + testresources + + diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/ByteOrderTests.java similarity index 77% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/ByteOrderTests.java index 5b5470bda..78b46eb60 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ByteOrderTests.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/ByteOrderTests.java @@ -25,10 +25,11 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.graph.FlatArray; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; @@ -41,12 +42,9 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j -@RunWith(Parameterized.class) -public class ByteOrderTests extends BaseNd4jTest { - public ByteOrderTests(Nd4jBackend backend) { - super(backend); - } +public class ByteOrderTests extends BaseNd4jTestWithBackends { + @AfterEach public void tearDown() { @@ -55,7 +53,9 @@ public class ByteOrderTests extends BaseNd4jTest { } @Test - public void testByteArrayOrder1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testByteArrayOrder1(Nd4jBackend backend) { val ndarray = Nd4j.create(DataType.FLOAT, 2).assign(1); assertEquals(DataType.FLOAT, ndarray.data().dataType()); @@ -66,7 +66,9 @@ public class ByteOrderTests extends BaseNd4jTest { } @Test - public void testByteArrayOrder2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testByteArrayOrder2(Nd4jBackend backend) { val original = Nd4j.linspace(1, 25, 25, DataType.FLOAT).reshape(5, 5); val bufferBuilder = new FlatBufferBuilder(0); @@ -82,7 +84,9 @@ public class ByteOrderTests extends BaseNd4jTest { @Test - public void testByteArrayOrder3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testByteArrayOrder3(Nd4jBackend backend) { val original = Nd4j.linspace(1, 25, 25, DataType.FLOAT).reshape('f', 5, 5); val bufferBuilder = new FlatBufferBuilder(0); @@ -97,7 +101,9 @@ public class ByteOrderTests extends BaseNd4jTest { } @Test - public void testShapeStridesOf1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShapeStridesOf1(Nd4jBackend backend) { val buffer = new int[]{2, 5, 5, 5, 1, 0, 1, 99}; val shape = Shape.shapeOf(buffer); @@ -108,7 +114,9 @@ public class ByteOrderTests extends BaseNd4jTest { } @Test - public void testShapeStridesOf2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testShapeStridesOf2(Nd4jBackend backend) { val buffer = new int[]{3, 5, 5, 5, 25, 5, 1, 0, 1, 99}; val shape = Shape.shapeOf(buffer); @@ -119,7 +127,9 @@ public class ByteOrderTests extends BaseNd4jTest { } @Test - public void testScalarEncoding() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testScalarEncoding(Nd4jBackend backend) { val scalar = Nd4j.scalar(2.0f); FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(0); @@ -137,7 +147,9 @@ public class ByteOrderTests extends BaseNd4jTest { @Test - public void testVectorEncoding_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorEncoding_1(Nd4jBackend backend) { val scalar = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5}); FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(0); @@ -153,7 +165,9 @@ public class ByteOrderTests extends BaseNd4jTest { } @Test - public void testVectorEncoding_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testVectorEncoding_2(Nd4jBackend backend) { val scalar = Nd4j.createFromArray(new double[]{1, 2, 3, 4, 5}); FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(0); @@ -169,7 +183,9 @@ public class ByteOrderTests extends BaseNd4jTest { } @Test - public void testStringEncoding_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testStringEncoding_1(Nd4jBackend backend) { val strings = Arrays.asList("alpha", "beta", "gamma"); val vector = Nd4j.create(strings, 3); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/ExecutionTests.java similarity index 68% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/ExecutionTests.java index b1cb771db..8cc8c8238 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/ExecutionTests.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/ExecutionTests.java @@ -24,16 +24,14 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.OpValidationSuite; -import org.nd4j.imports.tfgraphs.TFGraphTestZooModels; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.imports.graphmapper.tf.TFGraphMapper; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.common.io.ClassPathResource; import org.nd4j.nativeblas.NativeOpsHolder; @@ -42,12 +40,9 @@ import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j -@RunWith(Parameterized.class) -public class ExecutionTests extends BaseNd4jTest { - public ExecutionTests(Nd4jBackend backend) { - super(backend); - } +public class ExecutionTests extends BaseNd4jTestWithBackends { + @AfterEach public void tearDown() { @@ -57,17 +52,9 @@ public class ExecutionTests extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testStoredGraph_1() throws Exception { - if(TFGraphTestZooModels.isPPC()){ - /* - Ugly hack to temporarily disable tests on PPC only on CI - Issue logged here: https://github.com/eclipse/deeplearning4j/issues/7657 - These will be re-enabled for PPC once fixed - in the mean time, remaining tests will be used to detect and prevent regressions - */ - log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/eclipse/deeplearning4j/issues/7657"); - OpValidationSuite.ignoreFailing(); - } - Nd4j.create(1); val tg = TFGraphMapper.importGraphTxt(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream(), null, null); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/NameTests.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/NameTests.java similarity index 73% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/NameTests.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/NameTests.java index c92370f5d..8f87ef93f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/NameTests.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/NameTests.java @@ -23,24 +23,24 @@ package org.nd4j.imports; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j -@RunWith(Parameterized.class) -public class NameTests extends BaseNd4jTest { - public NameTests(Nd4jBackend backend) { - super(backend); - } +public class NameTests extends BaseNd4jTestWithBackends { + @Test - public void testNameExtraction_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNameExtraction_1(Nd4jBackend backend) { val str = "Name"; val exp = "Name"; @@ -51,7 +51,9 @@ public class NameTests extends BaseNd4jTest { @Test - public void testNameExtraction_2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNameExtraction_2(Nd4jBackend backend) { val str = "Name_2"; val exp = "Name_2"; @@ -61,7 +63,9 @@ public class NameTests extends BaseNd4jTest { } @Test - public void testNameExtraction_3() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNameExtraction_3(Nd4jBackend backend) { val str = "Name_1:2"; val exp = "Name_1"; @@ -71,7 +75,9 @@ public class NameTests extends BaseNd4jTest { } @Test - public void testNameExtraction_4() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNameExtraction_4(Nd4jBackend backend) { val str = "Name_1:1:2"; val exp = "Name_1:1"; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/TensorFlowImportTest.java similarity index 91% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/TensorFlowImportTest.java index efb6b5820..21d502fd7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/TensorFlowImportTest.java @@ -26,8 +26,9 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + import org.nd4j.autodiff.execution.conf.ExecutionMode; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; @@ -39,7 +40,7 @@ import org.nd4j.graph.FlatGraph; import org.nd4j.graph.FlatNode; import org.nd4j.imports.converters.DifferentialFunctionClassHolder; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; @@ -67,8 +68,8 @@ import static org.junit.jupiter.api.Assertions.*; @Slf4j @Disabled -@RunWith(Parameterized.class) -public class TensorFlowImportTest extends BaseNd4jTest { + +public class TensorFlowImportTest extends BaseNd4jTestWithBackends { private static ExecutorConfiguration configuration = ExecutorConfiguration.builder() .executionMode(ExecutionMode.SEQUENTIAL) .profilingMode(OpExecutioner.ProfilingMode.DISABLED) @@ -76,9 +77,6 @@ public class TensorFlowImportTest extends BaseNd4jTest { .outputMode(OutputMode.IMPLICIT) .build(); - public TensorFlowImportTest(Nd4jBackend backend) { - super(backend); - } @Override @@ -87,22 +85,26 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @BeforeEach - public void setUp() { + public void setUp(Nd4jBackend backend) { } @AfterEach - public void tearDown() { + public void tearDown(Nd4jBackend backend) { NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); } @Test - public void testClassHolder() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testClassHolder(Nd4jBackend backend) { DifferentialFunctionClassHolder.getInstance(); } @Test - public void testSingleExample_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testSingleExample_1(Nd4jBackend backend) { val g = TFGraphMapper.importGraph(new File("C:\\Users\\raver\\Downloads\\mnist.pb")); val array = Nd4j.ones(1, 28, 28); @@ -115,11 +117,15 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test - public void testAssertImport_1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testAssertImport_1(Nd4jBackend backend) { val graph = TFGraphMapper.importGraph(new File("C:\\Users\\raver\\Downloads\\test.pb")); } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testArgMaxImport_2() throws Exception { val graph = TFGraphMapper.importGraph(new ClassPathResource("/tf_graphs/examples/reductions/argmax3,4,5_-1/frozen_graph.pbtxt").getInputStream()); @@ -129,6 +135,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testArgMaxImport_1() throws Exception { val graph = TFGraphMapper.importGraph(new ClassPathResource("/tf_graphs/argmax.pb.txt").getInputStream()); @@ -141,20 +149,26 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - public void testHashEquality1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHashEquality1(Nd4jBackend backend) { long hash = HashUtil.getLongHash("Conv2D"); assertEquals(-1637140380760460323L, hash); } @Test - public void testHashEquality2() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testHashEquality2(Nd4jBackend backend) { long hash = HashUtil.getLongHash("switch"); assertEquals(-1988317239813741487L, hash); } @Test - public void testCustomOps1() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testCustomOps1(Nd4jBackend backend) { val map = Nd4j.getExecutioner().getCustomOperations(); assertTrue(map.size() > 0); @@ -236,6 +250,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLenet() throws Exception { /** * Produced with: @@ -261,12 +277,16 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIntermediate2() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/max_lstm.pb").getInputStream()); } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIntermediate1() throws Exception { Nd4j.create(1); @@ -287,6 +307,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIntermediateLoop1() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/simple_while.pb.txt").getInputStream()); @@ -303,13 +325,15 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test @Disabled - public void testWeirdConvImport() { + public void testWeirdConvImport(Nd4jBackend backend) { val tg = TFGraphMapper.importGraph(new File("/home/agibsonccc/code/raver_tfimport_test1/profiling_conv.pb.txt")); assertNotNull(tg); } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIntermediateLoop3() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/nested_while.pb.txt").getInputStream()); @@ -484,6 +508,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testIntermediateReduction() throws Exception { Nd4j.create(1); SameDiff tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream()); @@ -550,7 +576,9 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test - public void testDefaultArgs() { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testDefaultArgs(Nd4jBackend backend) { val op = new RectifiedLinear(); val extras = op.extraArgs(); @@ -561,6 +589,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testInferShape() throws IOException { /** * node { @@ -663,6 +693,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testImportMapping1() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/ae_00/frozen_model.pb").getInputStream()); @@ -683,6 +715,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCondMapping1() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); @@ -698,6 +732,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testCondMapping2() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()); @@ -715,6 +751,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWhileMapping1() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); @@ -734,6 +772,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWhileMapping2() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); @@ -752,6 +792,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWhileMapping3() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()); @@ -771,6 +813,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWhileDualMapping1() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()); @@ -791,6 +835,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testWhileDualMapping2() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()); @@ -812,6 +858,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testMixedWhileCond1() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_nested/frozen_model.pb").getInputStream()); @@ -968,6 +1016,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTensorArray_119_1() throws Exception { val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream()); assertNotNull(tg); @@ -981,6 +1031,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTensorArray_119_2() throws Exception { val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_read.pb.txt").getInputStream()); assertNotNull(tg); @@ -996,6 +1048,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTensorArray_119_3() throws Exception { Nd4j.create(1); @@ -1010,6 +1064,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testTensorArray_119_4() throws Exception { val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream()); assertNotNull(tg); @@ -1024,6 +1080,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLossImport_1() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/losses/log_loss_rank2_axis1_SUM_OVER_BATCH_SIZE/frozen_model.pb").getInputStream()); @@ -1032,6 +1090,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testG_1() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/g_08/frozen_model.pb").getInputStream()); @@ -1040,6 +1100,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testBoolImport_1() throws Exception { Nd4j.create(1); for (int e = 0; e < 1000; e++){ @@ -1053,6 +1115,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testLogical_1() throws Exception { Nd4j.create(1); val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/transforms/logicalxor_3,4_3,4/frozen_model.pb").getInputStream()); @@ -1061,6 +1125,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testSSD_1() throws Exception { // tf_graphs/examples/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb Nd4j.create(1); @@ -1078,6 +1144,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRandomGraph() throws Exception { val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/assert_equal/scalar_float32/frozen_model.pb").getInputStream()); assertNotNull(tg); @@ -1086,6 +1154,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { } @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testRandomGraph2() throws Exception { val tg = TFGraphMapper.importGraph(new File("c:\\develop\\mobilenet_v2_1.0_224_frozen.pb")); assertNotNull(tg); @@ -1105,6 +1175,8 @@ public class TensorFlowImportTest extends BaseNd4jTest { @Test + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") public void testControlDependencies1() throws Exception { SameDiff sd = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/cond/cond_true/frozen_model.pb").getInputStream()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TestReverse.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/TestReverse.java similarity index 81% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TestReverse.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/TestReverse.java index 16e6de0ff..e9677b7d2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TestReverse.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/TestReverse.java @@ -21,18 +21,17 @@ package org.nd4j.imports; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -public class TestReverse extends BaseNd4jTest { +public class TestReverse extends BaseNd4jTestWithBackends { - public TestReverse(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -40,7 +39,9 @@ public class TestReverse extends BaseNd4jTest { } @Test - public void testReverse(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverse(Nd4jBackend backend) { INDArray in = Nd4j.createFromArray(new double[]{1,2,3,4,5,6}); INDArray out = Nd4j.create(DataType.DOUBLE, 6); @@ -57,7 +58,9 @@ public class TestReverse extends BaseNd4jTest { } @Test - public void testReverse2(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testReverse2(Nd4jBackend backend){ INDArray in = Nd4j.createFromArray(new double[]{1,2,3,4,5,6}); INDArray axis = Nd4j.scalar(0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java similarity index 100% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportDebugListener.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/listeners/ImportDebugListener.java similarity index 100% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportDebugListener.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/listeners/ImportDebugListener.java diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java similarity index 100% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/listeners/ImportModelDebugger.java diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/BERTGraphTest.java similarity index 97% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/BERTGraphTest.java index 2545dd8fe..79051f579 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/BERTGraphTest.java @@ -23,6 +23,8 @@ package org.nd4j.imports.tfgraphs; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; @@ -35,7 +37,7 @@ import org.nd4j.graph.ui.LogFileWriter; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.tensorflow.TFImportOverride; import org.nd4j.imports.tensorflow.TFOpImportFilter; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -54,11 +56,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j @Disabled("AB 2019/05/21 - JVM Crash on linux-x86_64-cuda-9.2, linux-ppc64le-cpu - Issue #7657") -public class BERTGraphTest extends BaseNd4jTest { +public class BERTGraphTest extends BaseNd4jTestWithBackends { - public BERTGraphTest(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -66,7 +65,9 @@ public class BERTGraphTest extends BaseNd4jTest { } @Test - public void testBert() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBert(Nd4jBackend backend) throws Exception { String url = "https://dl4jdata.blob.core.windows.net/testresources/bert_mrpc_frozen_v1.zip"; File saveDir = new File(TFGraphTestZooModels.getBaseModelDir(), ".nd4jtests/bert_mrpc_frozen_v1"); @@ -275,7 +276,9 @@ public class BERTGraphTest extends BaseNd4jTest { } @Test //@Disabled //AB ignored 08/04/2019 until fixed - public void testBertTraining() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testBertTraining(Nd4jBackend backend) throws Exception { String url = "https://dl4jdata.blob.core.windows.net/testresources/bert_mrpc_frozen_v1.zip"; File saveDir = new File(TFGraphTestZooModels.getBaseModelDir(), ".nd4jtests/bert_mrpc_frozen_v1"); saveDir.mkdirs(); @@ -404,7 +407,7 @@ public class BERTGraphTest extends BaseNd4jTest { INDArray lossArr = sd.output(placeholderValues, "loss").get("loss"); assertTrue(lossArr.isScalar()); double scoreBefore = lossArr.getDouble(0); - for( int i=0; i<5; i++ ){ + for( int i = 0; i < 5; i++) { sd.fit(mds); } @@ -416,8 +419,11 @@ public class BERTGraphTest extends BaseNd4jTest { assertTrue( scoreAfter < scoreBefore,s); } - @Test @Disabled - public void writeBertUI() throws Exception { + @Test + @Disabled + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void writeBertUI(Nd4jBackend backend) throws Exception { //Test used to generate graph for visualization to work out appropriate subgraph structure to replace File f = new File("C:/Temp/TF_Graphs/mrpc_output/frozen/bert_mrpc_frozen.pb"); int minibatchSize = 4; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/CustomOpTests.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/CustomOpTests.java similarity index 84% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/CustomOpTests.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/CustomOpTests.java index 64006120e..d00c4e1bd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/CustomOpTests.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/CustomOpTests.java @@ -22,7 +22,9 @@ package org.nd4j.imports.tfgraphs; import lombok.val; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -32,11 +34,8 @@ import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -public class CustomOpTests extends BaseNd4jTest { +public class CustomOpTests extends BaseNd4jTestWithBackends { - public CustomOpTests(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -44,7 +43,9 @@ public class CustomOpTests extends BaseNd4jTest { } @Test - public void testPad(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testPad(Nd4jBackend backend){ INDArray in = Nd4j.create(DataType.FLOAT, 1, 28, 28, 264); INDArray pad = Nd4j.createFromArray(new int[][]{{0,0},{0,1},{0,1},{0,0}}); @@ -64,7 +65,9 @@ public class CustomOpTests extends BaseNd4jTest { } @Test - public void testResizeBilinearEdgeCase(){ + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testResizeBilinearEdgeCase(Nd4jBackend backend){ INDArray in = Nd4j.ones(DataType.FLOAT, 1, 1, 1, 3); INDArray size = Nd4j.createFromArray(8, 8); INDArray out = Nd4j.create(DataType.FLOAT, 1, 8, 8, 3); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/NodeReader.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/NodeReader.java similarity index 100% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/NodeReader.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/NodeReader.java diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/NodeReaderTests.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/NodeReaderTests.java similarity index 81% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/NodeReaderTests.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/NodeReaderTests.java index 268acae1c..8643ecabd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/NodeReaderTests.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/NodeReaderTests.java @@ -23,7 +23,9 @@ package org.nd4j.imports.tfgraphs; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.Test; -import org.nd4j.linalg.BaseNd4jTest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -31,11 +33,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @Slf4j -public class NodeReaderTests extends BaseNd4jTest { +public class NodeReaderTests extends BaseNd4jTestWithBackends { - public NodeReaderTests(Nd4jBackend b){ - super(b); - } @Override public char ordering(){ @@ -43,7 +42,9 @@ public class NodeReaderTests extends BaseNd4jTest { } @Test - public void testNodeReader_1() throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testNodeReader_1(Nd4jBackend backend) throws Exception { val array = NodeReader.readArray("ae_00", "BiasAdd.0"); val exp = Nd4j.create(new double[]{0.75157526, 0.73641957, 0.50457279, -0.45943720, 0.58269453, 0.10282226, -0.45269983, -0.05505687, -0.46887864, -0.05584033}, new long[]{5 ,2}); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestAllHelper.java similarity index 99% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestAllHelper.java index faeb04d22..6a069d545 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestAllHelper.java @@ -20,7 +20,8 @@ package org.nd4j.imports.tfgraphs; -import com.google.common.io.Files; +import org.nd4j.imports.listeners.ExecPrintListener; +import org.nd4j.imports.tfgraphs.listener.OpExecOrderListener; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.io.FilenameUtils; @@ -48,8 +49,6 @@ import org.nd4j.common.primitives.Pair; import org.nd4j.common.resources.strumpf.ResourceFile; import org.nd4j.common.resources.strumpf.StrumpfResolver; import org.nd4j.common.util.ArrayUtil; -import org.nd4j.imports.listeners.ExecPrintListener; -import org.nd4j.imports.tfgraphs.listener.OpExecOrderListener; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; @@ -63,6 +62,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.string.NDArrayStrings; import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFrameworkImporter; +import org.nd4j.shade.guava.io.Files; import org.springframework.core.io.FileSystemResource; import org.springframework.core.io.Resource; import org.springframework.core.io.support.PathMatchingResourcePatternResolver; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestAllLibnd4j.java similarity index 80% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestAllLibnd4j.java index 8a77be345..288093989 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestAllLibnd4j.java @@ -25,11 +25,8 @@ import lombok.val; import org.junit.jupiter.api.*;import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; -import org.junit.rules.TestWatcher; -import org.junit.runner.Description; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.OpValidationSuite; +import org.junit.jupiter.params.provider.Arguments; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; @@ -40,8 +37,9 @@ import org.nd4j.nativeblas.NativeOpsHolder; import java.io.File; import java.io.IOException; import java.util.*; +import java.util.stream.Stream; + -@RunWith(Parameterized.class) @Slf4j @Disabled("AB 2019/05/21 - JVM Crashes - Issue #7657") public class TFGraphTestAllLibnd4j { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests @@ -115,52 +113,36 @@ public class TFGraphTestAllLibnd4j { //Note: Can't extend BaseNd4jTest here as NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); } - @Parameterized.Parameters(name="{2}") - public static Collection data() throws IOException { + + public static Stream data() throws IOException { val localPath = System.getenv(TFGraphTestAllHelper.resourceFolderVar); // if this variable isn't set - we're using dl4j-tests-resources if (localPath == null) { File baseDir = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); - return TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); + return TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir).stream().map(Arguments::of); } else { File baseDir = new File(localPath); - return TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); + return TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir).stream().map(Arguments::of); } } - public TFGraphTestAllLibnd4j(Map inputs, Map predictions, String modelName, File localTestDir) { - this.inputs = inputs; - this.predictions = predictions; - this.modelName = modelName; - this.localTestDir = localTestDir; - } @Test//(timeout = 25000L) public void test() throws Exception { - if(TFGraphTestZooModels.isPPC()){ - /* - Ugly hack to temporarily disable tests on PPC only on CI - Issue logged here: https://github.com/eclipse/deeplearning4j/issues/7657 - These will be re-enabled for PPC once fixed - in the mean time, remaining tests will be used to detect and prevent regressions - */ - - log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/eclipse/deeplearning4j/issues/7657"); - OpValidationSuite.ignoreFailing(); - } Nd4j.create(1); for(String s : TFGraphTestAllSameDiff.IGNORE_REGEXES){ if(modelName.matches(s)){ log.info("\n\tIGNORE MODEL ON REGEX: {} - regex {}", modelName, s); - OpValidationSuite.ignoreFailing(); + //OpValidationSuite.ignoreFailing(); } } for(String s : SKIP_FOR_LIBND4J_EXEC){ if(modelName.matches(s)){ log.info("\n\tIGNORE MODEL ON REGEX - SKIP LIBND4J EXEC ONLY: {} - regex {}", modelName, s); - OpValidationSuite.ignoreFailing(); + //OpValidationSuite.ignoreFailing(); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestAllSameDiff.java similarity index 89% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestAllSameDiff.java index c2a916d42..1a7772fee 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestAllSameDiff.java @@ -23,10 +23,9 @@ package org.nd4j.imports.tfgraphs; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.jupiter.api.*; -import org.junit.runner.Description; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.OpValidationSuite; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; @@ -36,10 +35,9 @@ import org.nd4j.common.primitives.Pair; import java.io.File; import java.io.IOException; import java.util.*; +import java.util.stream.Stream; @Slf4j -@RunWith(Parameterized.class) -@Disabled public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests @@ -161,18 +159,17 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a public void tearDown() { } - @Parameterized.Parameters(name="{2}") - public static Collection data() throws IOException { + public static Stream data() throws IOException { val localPath = System.getenv(TFGraphTestAllHelper.resourceFolderVar); // if this variable isn't set - we're using dl4j-tests-resources if (localPath == null) { File baseDir = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); List params = TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); - return params; + return params.stream().map(input -> Arguments.of(input)); } else { File baseDir = new File(localPath); - return TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); + return TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir).stream().map(input -> Arguments.of(input)); } } @@ -184,30 +181,20 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a } @Test//(timeout = 25000L) + @ParameterizedTest public void testOutputOnly() throws Exception { - if(TFGraphTestZooModels.isPPC()) { - /* - Ugly hack to temporarily disable tests on PPC only on CI - Issue logged here: https://github.com/eclipse/deeplearning4j/issues/7657 - These will be re-enabled for PPC once fixed - in the mean time, remaining tests will be used to detect and prevent regressions - */ - - log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/eclipse/deeplearning4j/issues/7657"); - OpValidationSuite.ignoreFailing(); - } - Nd4j.create(1); if(EXECUTE_ONLY_MODELS.isEmpty()) { for(String s : IGNORE_REGEXES) { if(modelName.matches(s)) { log.info("\n\tIGNORE MODEL ON REGEX: {} - regex {}", modelName, s); - OpValidationSuite.ignoreFailing(); + //OpValidationSuite.ignoreFailing(); } } } else if(!EXECUTE_ONLY_MODELS.contains(modelName)) { log.info("Not executing " + modelName); - OpValidationSuite.ignoreFailing(); + //OpValidationSuite.ignoreFailing(); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestList.java similarity index 87% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestList.java index 455734817..81c34c72a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestList.java @@ -25,8 +25,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.io.TempDir; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.factory.Nd4j; @@ -37,11 +39,11 @@ import java.io.File; import java.io.IOException; import java.nio.file.Path; import java.util.ArrayList; -import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.stream.Stream; + -@RunWith(Parameterized.class) @Disabled public class TFGraphTestList { @@ -75,22 +77,21 @@ public class TFGraphTestList { private String modelName; - @Parameterized.Parameters - public static Collection data() { + + public static Stream data() { List modelNamesParams = new ArrayList<>(); for (int i = 0; i < modelNames.length; i++) { Object[] currentParams = new String[]{modelNames[i]}; modelNamesParams.add(currentParams); } - return modelNamesParams; + return modelNamesParams.stream().map(Arguments::of); } - public TFGraphTestList(String modelName) { - this.modelName = modelName; - } @Test - public void testOutputOnly(@TempDir Path testDir) throws IOException { + @ParameterizedTest + @MethodSource("#data") + public void testOutputOnly(@TempDir Path testDir,String modelName) throws IOException { //Nd4jCpu.Environment.getInstance().setUseMKLDNN(false); File dir = testDir.toFile(); Map inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir); @@ -104,7 +105,9 @@ public class TFGraphTestList { } @Test @Disabled - public void testAlsoIntermediate(@TempDir Path testDir) throws IOException { + @ParameterizedTest + @MethodSource("#data") + public void testAlsoIntermediate(@TempDir Path testDir,String modelName) throws IOException { //Nd4jCpu.Environment.getInstance().setUseMKLDNN(false); File dir = testDir.toFile(); Map inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestZooModels.java similarity index 94% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestZooModels.java index f5e0f1130..9300a56aa 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphTestZooModels.java @@ -27,9 +27,10 @@ import org.apache.commons.lang3.ArrayUtils; import org.junit.jupiter.api.*; import org.junit.jupiter.api.io.TempDir; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.OpValidationSuite; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; @@ -46,11 +47,11 @@ import java.net.URL; import java.nio.charset.StandardCharsets; import java.nio.file.Path; import java.util.ArrayList; -import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.stream.Stream; + -@RunWith(Parameterized.class) @Slf4j @Disabled public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests @@ -211,19 +212,11 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); } - @Parameterized.Parameters(name="{2}") - public static Collection data() throws IOException { + public static Stream data() throws IOException { classTestDir.toFile().mkdir(); File baseDir = classTestDir.toFile(); // new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); List params = TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.SAMEDIFF, baseDir); - return params; - } - - public TFGraphTestZooModels(Map inputs, Map predictions, String modelName, File localTestDir) { - this.inputs = inputs; - this.predictions = predictions; - this.modelName = modelName; - this.localTestDir = localTestDir; + return params.stream().map(Arguments::of); } private static Boolean isPPC = null; @@ -240,6 +233,8 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we } @Test //(timeout = 360000L) + @ParameterizedTest + @MethodSource("#data") public void testOutputOnly(@TempDir Path testDir) throws Exception { if(isPPC()){ /* @@ -249,7 +244,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we */ log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/eclipse/deeplearning4j/issues/7657"); - OpValidationSuite.ignoreFailing(); + //OpValidationSuite.ignoreFailing(); } // if(!modelName.startsWith("ssd_mobilenet_v1_coco_2018_01_28")){ @@ -265,7 +260,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we Nd4j.create(1); if(ArrayUtils.contains(IGNORE_REGEXES, modelName)){ log.info("\n\tIGNORE MODEL ON REGEX: {} - regex {}", modelName, modelName); - OpValidationSuite.ignoreFailing(); + // OpValidationSuite.ignoreFailing(); } Double maxRE = 1e-3; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphsSkipNodes.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphsSkipNodes.java similarity index 100% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphsSkipNodes.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/TFGraphsSkipNodes.java diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/ValidateZooModelPredictions.java similarity index 87% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/ValidateZooModelPredictions.java index 17a2cd3b2..a161e24e2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/ValidateZooModelPredictions.java +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/ValidateZooModelPredictions.java @@ -28,9 +28,10 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; -import org.nd4j.OpValidationSuite; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -47,11 +48,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j @Disabled -public class ValidateZooModelPredictions extends BaseNd4jTest { +public class ValidateZooModelPredictions extends BaseNd4jTestWithBackends { - public ValidateZooModelPredictions(Nd4jBackend backend) { - super(backend); - } @Override public char ordering() { @@ -73,18 +71,9 @@ public class ValidateZooModelPredictions extends BaseNd4jTest { } @Test - public void testMobilenetV1(@TempDir Path testDir) throws Exception { - if(TFGraphTestZooModels.isPPC()){ - /* - Ugly hack to temporarily disable tests on PPC only on CI - Issue logged here: https://github.com/eclipse/deeplearning4j/issues/7657 - These will be re-enabled for PPC once fixed - in the mean time, remaining tests will be used to detect and prevent regressions - */ - - log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/eclipse/deeplearning4j/issues/7657"); - OpValidationSuite.ignoreFailing(); - } - + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testMobilenetV1(@TempDir Path testDir,Nd4jBackend backend) throws Exception { TFGraphTestZooModels.currentTestDir = testDir.toFile(); //Load model @@ -138,7 +127,9 @@ public class ValidateZooModelPredictions extends BaseNd4jTest { @Test - public void testResnetV2(@TempDir Path testDir) throws Exception { + @ParameterizedTest + @MethodSource("org.nd4j.linalg.BaseNd4jTest#configs") + public void testResnetV2(@TempDir Path testDir,Nd4jBackend backend) throws Exception { if(TFGraphTestZooModels.isPPC()){ /* Ugly hack to temporarily disable tests on PPC only on CI @@ -147,7 +138,7 @@ public class ValidateZooModelPredictions extends BaseNd4jTest { */ log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/eclipse/deeplearning4j/issues/7657"); - OpValidationSuite.ignoreFailing(); + //OpValidationSuite.ignoreFailing(); } TFGraphTestZooModels.currentTestDir = testDir.toFile(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/listener/OpExecOrderListener.java b/nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/listener/OpExecOrderListener.java similarity index 100% rename from nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/listener/OpExecOrderListener.java rename to nd4j/samediff-import/samediff-import-tensorflow/src/test/java/org/nd4j/imports/tfgraphs/listener/OpExecOrderListener.java diff --git a/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt b/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt index 96303b8f3..2c711166c 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt +++ b/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt @@ -10307,6 +10307,41 @@ mappings { inputFrameworkOpName: "UniqueWithCounts" } } +mappings { + frameworkName: "tensorflow" + opName: "ctc_loss" + inputFrameworkOpName: "CTCLoss" + rule { + ruleName: "ndarraymapping" + functionName: "ndarraymapping" + inputTensorName: "inputs" + inputTensorName: "labels_values" + inputTensorName: "labels_indices" + inputTensorName: "sequence_length" + outputTensorName: "logitInput" + outputTensorName: "targetLabels" + outputTensorName: "targetLabelLengths" + outputTensorName: "logitInputLengths" + inputToOutput { + key: "logitInput" + value: "inputs" + } + inputToOutput { + key: "targetLabels" + value: "labels_values" + } + inputToOutput { + key: "targetLabelLengths" + value: "labels_indices" + } + inputToOutput { + key: "logitInputLengths" + value: "sequence_length" + } + ruleType: "tensor" + inputFrameworkOpName: "CTCLoss" + } +} mappings { frameworkName: "tensorflow" opName: "randomuniform" diff --git a/pom.xml b/pom.xml index bf2503468..6080a69dc 100644 --- a/pom.xml +++ b/pom.xml @@ -327,6 +327,13 @@ + + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + org.junit.jupiter junit-jupiter-api diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java index 0332c6d94..85c319eb9 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java @@ -19,11 +19,13 @@ */ +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.python4j.*; import org.junit.Assert; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -35,20 +37,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; +import java.util.stream.Stream; @NotThreadSafe -@RunWith(Parameterized.class) public class PythonNumpyBasicTest { - private DataType dataType; - private long[] shape; - - public PythonNumpyBasicTest(DataType dataType, long[] shape, String dummyArg) { - this.dataType = dataType; - this.shape = shape; - } - - @Parameterized.Parameters(name = "{index}: Testing with DataType={0}, shape={2}") - public static Collection params() { + public static Stream params() { DataType[] types = new DataType[] { DataType.BOOL, DataType.FLOAT16, @@ -79,11 +72,13 @@ public class PythonNumpyBasicTest { ret.add(new Object[]{type, shape, Arrays.toString(shape)}); } } - return ret; + return ret.stream().map(Arguments::of); } @Test - public void testConversion(){ + @ParameterizedTest + @MethodSource("#params") + public void testConversion(DataType dataType,long[] shape){ try(PythonGIL pythonGIL = PythonGIL.lock()) { INDArray arr = Nd4j.zeros(dataType, shape); PythonObject npArr = PythonTypes.convert(arr); @@ -98,7 +93,9 @@ public class PythonNumpyBasicTest { @Test - public void testExecution() { + @ParameterizedTest + @MethodSource("#params") + public void testExecution(DataType dataType,long[] shape) { try(PythonGIL pythonGIL = PythonGIL.lock()) { List inputs = new ArrayList<>(); INDArray x = Nd4j.ones(dataType, shape); @@ -127,7 +124,9 @@ public class PythonNumpyBasicTest { @Test - public void testInplaceExecution() { + @ParameterizedTest + @MethodSource("#params") + public void testInplaceExecution(DataType dataType,long[] shape) { try(PythonGIL pythonGIL = PythonGIL.lock()) { if (dataType == DataType.BOOL || dataType == DataType.BFLOAT16)return; if (shape.length == 0) return; diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java index 2dbe8305c..c3198c19f 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java @@ -19,33 +19,30 @@ */ +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.python4j.PythonException; import org.nd4j.python4j.PythonGIL; import org.nd4j.python4j.PythonObject; import org.nd4j.python4j.PythonTypes; import org.junit.Assert; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import javax.annotation.concurrent.NotThreadSafe; import java.util.*; +import java.util.stream.Stream; @NotThreadSafe -@RunWith(Parameterized.class) public class PythonNumpyCollectionsTest { - private DataType dataType; - public PythonNumpyCollectionsTest(DataType dataType){ - this.dataType = dataType; - } - @Parameterized.Parameters(name = "{index}: Testing with DataType={0}") - public static DataType[] params() { - return new DataType[]{ + public static Stream params() { + return Arrays.asList(new DataType[]{ DataType.BOOL, DataType.FLOAT16, //DataType.BFLOAT16, @@ -59,10 +56,13 @@ public class PythonNumpyCollectionsTest { DataType.UINT16, DataType.UINT32, DataType.UINT64 - }; + }).stream().map(Arguments::of); } + @Test - public void testPythonDictFromMap() throws PythonException { + @MethodSource("#params") + @ParameterizedTest + public void testPythonDictFromMap(DataType dataType) throws PythonException { try(PythonGIL pythonGIL = PythonGIL.lock()) { Map map = new HashMap(); map.put("a", 1); @@ -83,7 +83,9 @@ public class PythonNumpyCollectionsTest { } @Test - public void testPythonListFromList() throws PythonException { + @MethodSource("#params") + @ParameterizedTest + public void testPythonListFromList(DataType dataType) throws PythonException { try(PythonGIL pythonGIL = PythonGIL.lock()) { List list = new ArrayList<>(); list.add(1); diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java index 17a794015..3f64e8678 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java @@ -18,11 +18,13 @@ * ***************************************************************************** */ +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.nd4j.python4j.*; import org.junit.Assert; import org.junit.jupiter.api.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; + import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,20 +34,14 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.stream.Stream; @NotThreadSafe -@RunWith(Parameterized.class) public class PythonNumpyMultiThreadTest { - private DataType dataType; - public PythonNumpyMultiThreadTest(DataType dataType) { - this.dataType = dataType; - } - - @Parameterized.Parameters(name = "{index}: Testing with DataType={0}") - public static DataType[] params() { - return new DataType[]{ + public static Stream params() { + return Arrays.asList(new DataType[]{ // DataType.BOOL, // DataType.FLOAT16, // DataType.BFLOAT16, @@ -59,29 +55,28 @@ public class PythonNumpyMultiThreadTest { // DataType.UINT16, // DataType.UINT32, // DataType.UINT64 - }; + }).stream().map(Arguments::of); } @Test - public void testMultiThreading1() throws Throwable { + @MethodSource("#params") + @ParameterizedTest + public void testMultiThreading1(DataType dataType) throws Throwable { final List exceptions = Collections.synchronizedList(new ArrayList()); - Runnable runnable = new Runnable() { - @Override - public void run() { - try (PythonGIL gil = PythonGIL.lock()) { - try (PythonGC gc = PythonGC.watch()) { - List inputs = new ArrayList<>(); - inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); - inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); - PythonVariable out = new PythonVariable<>("z", NumpyArray.INSTANCE); - String code = "z = x + y"; - PythonExecutioner.exec(code, inputs, Collections.singletonList(out)); - Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), out.getValue()); - } - } catch (Throwable e) { - exceptions.add(e); + Runnable runnable = () -> { + try (PythonGIL gil = PythonGIL.lock()) { + try (PythonGC gc = PythonGC.watch()) { + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + PythonVariable out = new PythonVariable<>("z", NumpyArray.INSTANCE); + String code = "z = x + y"; + PythonExecutioner.exec(code, inputs, Collections.singletonList(out)); + Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), out.getValue()); } + } catch (Throwable e) { + exceptions.add(e); } }; @@ -104,8 +99,10 @@ public class PythonNumpyMultiThreadTest { } @Test - public void testMultiThreading2() throws Throwable { - final List exceptions = Collections.synchronizedList(new ArrayList()); + @MethodSource("#params") + @ParameterizedTest + public void testMultiThreading2(DataType dataType) throws Throwable { + final List exceptions = Collections.synchronizedList(new ArrayList<>()); Runnable runnable = new Runnable() { @Override public void run() {