Migrate parameterized tests to junit 5
parent
82bdcc21d2
commit
3c6014271e
|
@ -37,8 +37,10 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
|||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -47,13 +49,14 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.learning.config.NoOp;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
import java.util.Arrays;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.deeplearning4j.nn.conf.ConvolutionMode.Same;
|
||||
import static org.deeplearning4j.nn.conf.ConvolutionMode.Truncate;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
@DisplayName("Cnn Gradient Check Test")
|
||||
class CNNGradientCheckTest extends BaseDL4JTest {
|
||||
|
||||
|
@ -71,15 +74,10 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
Nd4j.setDataType(DataType.DOUBLE);
|
||||
}
|
||||
|
||||
private CNN2DFormat format;
|
||||
|
||||
public CNNGradientCheckTest(CNN2DFormat format) {
|
||||
this.format = format;
|
||||
}
|
||||
|
||||
@Parameterized.Parameters(name = "{0}")
|
||||
public static Object[] params() {
|
||||
return CNN2DFormat.values();
|
||||
public static Stream<Arguments> params() {
|
||||
return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -89,9 +87,11 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Gradient CNNMLN")
|
||||
void testGradientCNNMLN() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testGradientCNNMLN(CNN2DFormat format) {
|
||||
if (// Only test NCHW due to flat input format...
|
||||
this.format != CNN2DFormat.NCHW)
|
||||
format != CNN2DFormat.NCHW)
|
||||
return;
|
||||
// Parameterized test, testing combinations of:
|
||||
// (a) activation function
|
||||
|
@ -146,9 +146,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Gradient CNNL 1 L 2 MLN")
|
||||
void testGradientCNNL1L2MLN() {
|
||||
void testGradientCNNL1L2MLN(CNN2DFormat format) {
|
||||
if (// Only test NCHW due to flat input format...
|
||||
this.format != CNN2DFormat.NCHW)
|
||||
format != CNN2DFormat.NCHW)
|
||||
return;
|
||||
// Parameterized test, testing combinations of:
|
||||
// (a) activation function
|
||||
|
@ -245,7 +245,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Cnn With Space To Batch")
|
||||
void testCnnWithSpaceToBatch() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testCnnWithSpaceToBatch(CNN2DFormat format) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nOut = 4;
|
||||
int[] minibatchSizes = { 2, 4 };
|
||||
|
@ -289,7 +291,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Cnn With Upsampling")
|
||||
void testCnnWithUpsampling() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testCnnWithUpsampling(CNN2DFormat format) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nOut = 4;
|
||||
int[] minibatchSizes = { 1, 3 };
|
||||
|
@ -323,7 +327,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Cnn With Subsampling")
|
||||
void testCnnWithSubsampling() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testCnnWithSubsampling(CNN2DFormat format) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nOut = 4;
|
||||
int[] minibatchSizes = { 1, 3 };
|
||||
|
@ -365,7 +371,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Cnn With Subsampling V 2")
|
||||
void testCnnWithSubsamplingV2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testCnnWithSubsamplingV2(CNN2DFormat format) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nOut = 4;
|
||||
int[] minibatchSizes = { 1, 3 };
|
||||
|
@ -403,7 +411,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Cnn Locally Connected 2 D")
|
||||
void testCnnLocallyConnected2D() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testCnnLocallyConnected2D(CNN2DFormat format) {
|
||||
int nOut = 3;
|
||||
int width = 5;
|
||||
int height = 5;
|
||||
|
@ -433,7 +443,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Cnn Multi Layer")
|
||||
void testCnnMultiLayer() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testCnnMultiLayer(CNN2DFormat format) {
|
||||
int nOut = 2;
|
||||
int[] minibatchSizes = { 1, 2, 5 };
|
||||
int width = 5;
|
||||
|
@ -473,7 +485,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Cnn Same Padding Mode")
|
||||
void testCnnSamePaddingMode() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testCnnSamePaddingMode(CNN2DFormat format) {
|
||||
int nOut = 2;
|
||||
int[] minibatchSizes = { 1, 3, 3, 2, 1, 2 };
|
||||
// Same padding mode: insensitive to exact input size...
|
||||
|
@ -507,7 +521,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Cnn Same Padding Mode Strided")
|
||||
void testCnnSamePaddingModeStrided() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testCnnSamePaddingModeStrided(CNN2DFormat format) {
|
||||
int nOut = 2;
|
||||
int[] minibatchSizes = { 1, 3 };
|
||||
int width = 16;
|
||||
|
@ -550,7 +566,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Cnn Zero Padding Layer")
|
||||
void testCnnZeroPaddingLayer() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testCnnZeroPaddingLayer(CNN2DFormat format) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nOut = 4;
|
||||
int width = 6;
|
||||
|
@ -596,7 +614,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Deconvolution 2 D")
|
||||
void testDeconvolution2D() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testDeconvolution2D(CNN2DFormat format) {
|
||||
int nOut = 2;
|
||||
int[] minibatchSizes = new int[] { 1, 3, 3, 1, 3 };
|
||||
int[] kernelSizes = new int[] { 1, 1, 1, 3, 3 };
|
||||
|
@ -641,7 +661,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Separable Conv 2 D")
|
||||
void testSeparableConv2D() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testSeparableConv2D(CNN2DFormat format) {
|
||||
int nOut = 2;
|
||||
int width = 6;
|
||||
int height = 6;
|
||||
|
@ -686,7 +708,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Cnn Dilated")
|
||||
void testCnnDilated() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testCnnDilated(CNN2DFormat format) {
|
||||
int nOut = 2;
|
||||
int minibatchSize = 2;
|
||||
int width = 8;
|
||||
|
@ -736,7 +760,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Cropping 2 D Layer")
|
||||
void testCropping2DLayer() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testCropping2DLayer(CNN2DFormat format) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nOut = 2;
|
||||
int width = 12;
|
||||
|
@ -780,7 +806,9 @@ class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Depthwise Conv 2 D")
|
||||
void testDepthwiseConv2D() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testDepthwiseConv2D(CNN2DFormat format) {
|
||||
int nIn = 3;
|
||||
int depthMultiplier = 2;
|
||||
int nOut = nIn * depthMultiplier;
|
||||
|
|
|
@ -39,8 +39,10 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -55,26 +57,22 @@ import java.io.File;
|
|||
import java.io.FileOutputStream;
|
||||
import java.io.InputStream;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class YoloGradientCheckTests extends BaseDL4JTest {
|
||||
|
||||
static {
|
||||
Nd4j.setDataType(DataType.DOUBLE);
|
||||
}
|
||||
|
||||
private CNN2DFormat format;
|
||||
public YoloGradientCheckTests(CNN2DFormat format){
|
||||
this.format = format;
|
||||
}
|
||||
@Parameterized.Parameters(name = "{0}")
|
||||
public static Object[] params(){
|
||||
return CNN2DFormat.values();
|
||||
}
|
||||
|
||||
public static Stream<Arguments> params() {
|
||||
return Arrays.asList(CNN2DFormat.values()).stream().map(Arguments::of);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
|
@ -82,7 +80,9 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testYoloOutputLayer() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testYoloOutputLayer(CNN2DFormat format) {
|
||||
int depthIn = 2;
|
||||
int c = 3;
|
||||
int b = 3;
|
||||
|
@ -159,13 +159,13 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
private static INDArray yoloLabels(int mb, int c, int h, int w){
|
||||
private static INDArray yoloLabels(int mb, int c, int h, int w) {
|
||||
int labelDepth = 4 + c;
|
||||
INDArray labels = Nd4j.zeros(mb, labelDepth, h, w);
|
||||
//put 1 object per minibatch, at positions (0,0), (1,1) etc.
|
||||
//Positions for label boxes: (1,1) to (2,2), (2,2) to (4,4) etc
|
||||
|
||||
for( int i=0; i<mb; i++ ){
|
||||
for( int i = 0; i < mb; i++) {
|
||||
//Class labels
|
||||
labels.putScalar(i, 4 + i%c, i%h, i%w, 1);
|
||||
|
||||
|
@ -181,7 +181,9 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void yoloGradientCheckRealData(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void yoloGradientCheckRealData(@TempDir Path testDir,CNN2DFormat format) throws Exception {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
InputStream is1 = new ClassPathResource("yolo/VOC_TwoImage/JPEGImages/2007_009346.jpg").getInputStream();
|
||||
InputStream is2 = new ClassPathResource("yolo/VOC_TwoImage/Annotations/2007_009346.xml").getInputStream();
|
||||
|
|
|
@ -39,8 +39,10 @@ import org.deeplearning4j.nn.workspace.ArrayType;
|
|||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.deeplearning4j.util.ConvolutionUtils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -48,24 +50,19 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.common.primitives.Pair;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class ConvDataFormatTests extends BaseDL4JTest {
|
||||
|
||||
private final DataType dataType;
|
||||
|
||||
public ConvDataFormatTests(DataType dataType){
|
||||
this.dataType = dataType;
|
||||
}
|
||||
|
||||
@Parameterized.Parameters(name = "{0}")
|
||||
public static Object[] params(){
|
||||
return new DataType[]{DataType.FLOAT, DataType.DOUBLE};
|
||||
public static Stream<Arguments> params(){
|
||||
return Arrays.asList(new DataType[]{DataType.FLOAT, DataType.DOUBLE}).stream().map(Arguments::of);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -74,7 +71,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testConv2d() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testConv2d(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||
|
@ -83,15 +82,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getConv2dNet(CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getConv2dNet(CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getConv2dNet(CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getConv2dNet(CNN2DFormat.NHWC, false, cm))
|
||||
.net1(getConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -107,7 +106,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSubsampling2d() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testSubsampling2d(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||
|
@ -116,15 +117,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getSubsampling2dNet(CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getSubsampling2dNet(CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getSubsampling2dNet(CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getSubsampling2dNet(CNN2DFormat.NHWC, false, cm))
|
||||
.net1(getSubsampling2dNet(dataType,CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getSubsampling2dNet(dataType,CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getSubsampling2dNet(dataType,CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getSubsampling2dNet(dataType,CNN2DFormat.NHWC, false, cm))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -140,7 +141,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDepthwiseConv2d() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testDepthwiseConv2d(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||
|
@ -149,15 +152,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getDepthwiseConv2dNet(CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getDepthwiseConv2dNet(CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getDepthwiseConv2dNet(CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getDepthwiseConv2dNet(CNN2DFormat.NHWC, false, cm))
|
||||
.net1(getDepthwiseConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getDepthwiseConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getDepthwiseConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getDepthwiseConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -173,7 +176,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSeparableConv2d() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testSeparableConv2d(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||
|
@ -182,15 +187,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getSeparableConv2dNet(CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getSeparableConv2dNet(CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getSeparableConv2dNet(CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getSeparableConv2dNet(CNN2DFormat.NHWC, false, cm))
|
||||
.net1(getSeparableConv2dNet(dataType,CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getSeparableConv2dNet(dataType,CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getSeparableConv2dNet(dataType,CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getSeparableConv2dNet(dataType,CNN2DFormat.NHWC, false, cm))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -206,7 +211,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDeconv2d() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testDeconv2d(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||
|
@ -215,15 +222,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getDeconv2DNet2dNet(CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getDeconv2DNet2dNet(CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getDeconv2DNet2dNet(CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getDeconv2DNet2dNet(CNN2DFormat.NHWC, false, cm))
|
||||
.net1(getDeconv2DNet2dNet(dataType,CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getDeconv2DNet2dNet(dataType,CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getDeconv2DNet2dNet(dataType,CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getDeconv2DNet2dNet(dataType,CNN2DFormat.NHWC, false, cm))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -239,7 +246,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLRN() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testLRN(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||
|
@ -248,15 +257,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getLrnLayer(CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getLrnLayer(CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getLrnLayer(CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getLrnLayer(CNN2DFormat.NHWC, false, cm))
|
||||
.net1(getLrnLayer(dataType,CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getLrnLayer(dataType,CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getLrnLayer(dataType,CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getLrnLayer(dataType,CNN2DFormat.NHWC, false, cm))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -272,7 +281,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testZeroPaddingLayer(){
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testZeroPaddingLayer(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -280,15 +291,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers" : "No helpers";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getZeroPaddingNet(CNN2DFormat.NCHW, true))
|
||||
.net2(getZeroPaddingNet(CNN2DFormat.NCHW, false))
|
||||
.net3(getZeroPaddingNet(CNN2DFormat.NHWC, true))
|
||||
.net4(getZeroPaddingNet(CNN2DFormat.NHWC, false))
|
||||
.net1(getZeroPaddingNet(dataType,CNN2DFormat.NCHW, true))
|
||||
.net2(getZeroPaddingNet(dataType,CNN2DFormat.NCHW, false))
|
||||
.net3(getZeroPaddingNet(dataType,CNN2DFormat.NHWC, true))
|
||||
.net4(getZeroPaddingNet(dataType,CNN2DFormat.NHWC, false))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -303,7 +314,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testCropping2DLayer(){
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testCropping2DLayer(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -311,15 +324,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers" : "No helpers";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getCropping2dNet(CNN2DFormat.NCHW, true))
|
||||
.net2(getCropping2dNet(CNN2DFormat.NCHW, false))
|
||||
.net3(getCropping2dNet(CNN2DFormat.NHWC, true))
|
||||
.net4(getCropping2dNet(CNN2DFormat.NHWC, false))
|
||||
.net1(getCropping2dNet(dataType,CNN2DFormat.NCHW, true))
|
||||
.net2(getCropping2dNet(dataType,CNN2DFormat.NCHW, false))
|
||||
.net3(getCropping2dNet(dataType,CNN2DFormat.NHWC, true))
|
||||
.net4(getCropping2dNet(dataType,CNN2DFormat.NHWC, false))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -334,7 +347,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testUpsampling2d(){
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testUpsampling2d(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -342,15 +357,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers" : "No helpers";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getUpsamplingNet(CNN2DFormat.NCHW, true))
|
||||
.net2(getUpsamplingNet(CNN2DFormat.NCHW, false))
|
||||
.net3(getUpsamplingNet(CNN2DFormat.NHWC, true))
|
||||
.net4(getUpsamplingNet(CNN2DFormat.NHWC, false))
|
||||
.net1(getUpsamplingNet(dataType,CNN2DFormat.NCHW, true))
|
||||
.net2(getUpsamplingNet(dataType,CNN2DFormat.NCHW, false))
|
||||
.net3(getUpsamplingNet(dataType,CNN2DFormat.NHWC, true))
|
||||
.net4(getUpsamplingNet(dataType,CNN2DFormat.NHWC, false))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -365,7 +380,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBatchNormNet(){
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testBatchNormNet(DataType dataType) {
|
||||
try {
|
||||
for(boolean useLogStd : new boolean[]{true, false}) {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
|
@ -374,15 +391,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std");
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, true))
|
||||
.net2(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, false))
|
||||
.net3(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, true))
|
||||
.net4(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, false))
|
||||
.net1(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NCHW, true))
|
||||
.net2(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NCHW, false))
|
||||
.net3(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NHWC, true))
|
||||
.net4(getBatchNormNet(dataType,useLogStd, CNN2DFormat.NHWC, false))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -398,7 +415,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testCnnLossLayer() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testCnnLossLayer(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -406,8 +425,8 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers" : "No helpers";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray labelsNHWC = TestUtils.randomOneHot(this.dataType,2*6*6, 3);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labelsNHWC = TestUtils.randomOneHot(dataType,2*6*6, 3);
|
||||
labelsNHWC = labelsNHWC.reshape(2,6,6,3);
|
||||
INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup();
|
||||
|
||||
|
@ -434,7 +453,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSpaceToDepthNet(){
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testSpaceToDepthNet(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -442,15 +463,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers" : "No helpers";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getSpaceToDepthNet(CNN2DFormat.NCHW, true))
|
||||
.net2(getSpaceToDepthNet(CNN2DFormat.NCHW, false))
|
||||
.net3(getSpaceToDepthNet(CNN2DFormat.NHWC, true))
|
||||
.net4(getSpaceToDepthNet(CNN2DFormat.NHWC, false))
|
||||
.net1(getSpaceToDepthNet(dataType,CNN2DFormat.NCHW, true))
|
||||
.net2(getSpaceToDepthNet(dataType,CNN2DFormat.NCHW, false))
|
||||
.net3(getSpaceToDepthNet(dataType,CNN2DFormat.NHWC, true))
|
||||
.net4(getSpaceToDepthNet(dataType,CNN2DFormat.NHWC, false))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -465,7 +486,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSpaceToBatchNet(){
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testSpaceToBatchNet(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -473,15 +496,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers" : "No helpers";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 16, 16);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 16, 16);
|
||||
INDArray labels = TestUtils.randomOneHot(8, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getSpaceToBatchNet(CNN2DFormat.NCHW, true))
|
||||
.net2(getSpaceToBatchNet(CNN2DFormat.NCHW, false))
|
||||
.net3(getSpaceToBatchNet(CNN2DFormat.NHWC, true))
|
||||
.net4(getSpaceToBatchNet(CNN2DFormat.NHWC, false))
|
||||
.net1(getSpaceToBatchNet(dataType,CNN2DFormat.NCHW, true))
|
||||
.net2(getSpaceToBatchNet(dataType,CNN2DFormat.NCHW, false))
|
||||
.net3(getSpaceToBatchNet(dataType,CNN2DFormat.NHWC, true))
|
||||
.net4(getSpaceToBatchNet(dataType,CNN2DFormat.NHWC, false))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -496,7 +519,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLocallyConnected() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testLocallyConnected(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) {
|
||||
|
@ -505,15 +530,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getLocallyConnectedNet(CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getLocallyConnectedNet(CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getLocallyConnectedNet(CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getLocallyConnectedNet(CNN2DFormat.NHWC, false, cm))
|
||||
.net1(getLocallyConnectedNet(dataType,CNN2DFormat.NCHW, true, cm))
|
||||
.net2(getLocallyConnectedNet(dataType,CNN2DFormat.NCHW, false, cm))
|
||||
.net3(getLocallyConnectedNet(dataType,CNN2DFormat.NHWC, true, cm))
|
||||
.net4(getLocallyConnectedNet(dataType,CNN2DFormat.NHWC, false, cm))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -530,7 +555,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testGlobalPooling() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testGlobalPooling(DataType dataType) {
|
||||
try {
|
||||
for (boolean helpers : new boolean[]{false, true}) {
|
||||
for (PoolingType pt : PoolingType.values()) {
|
||||
|
@ -539,15 +566,15 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")";
|
||||
System.out.println(" --- " + msg + " ---");
|
||||
|
||||
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
|
||||
INDArray inNCHW = Nd4j.rand(dataType, 2, 3, 12, 12);
|
||||
INDArray labels = TestUtils.randomOneHot(2, 10);
|
||||
|
||||
TestCase tc = TestCase.builder()
|
||||
.msg(msg)
|
||||
.net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true))
|
||||
.net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false))
|
||||
.net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true))
|
||||
.net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false))
|
||||
.net1(getGlobalPoolingNet(dataType,CNN2DFormat.NCHW, pt, true))
|
||||
.net2(getGlobalPoolingNet(dataType,CNN2DFormat.NCHW, pt, false))
|
||||
.net3(getGlobalPoolingNet(dataType,CNN2DFormat.NHWC, pt, true))
|
||||
.net4(getGlobalPoolingNet(dataType,CNN2DFormat.NHWC, pt, false))
|
||||
.inNCHW(inNCHW)
|
||||
.labelsNCHW(labels)
|
||||
.labelsNHWC(labels)
|
||||
|
@ -562,9 +589,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
private MultiLayerNetwork getConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new ConvolutionLayer.Builder()
|
||||
return getNetWithLayer(dataType,new ConvolutionLayer.Builder()
|
||||
.kernelSize(3, 3)
|
||||
.stride(2, 2)
|
||||
.activation(Activation.TANH)
|
||||
|
@ -573,7 +600,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
.helperAllowFallback(false)
|
||||
.build(), format, cm, null);
|
||||
} else {
|
||||
return getNetWithLayer(new ConvolutionLayer.Builder()
|
||||
return getNetWithLayer(dataType,new ConvolutionLayer.Builder()
|
||||
.kernelSize(3, 3)
|
||||
.stride(2, 2)
|
||||
.activation(Activation.TANH)
|
||||
|
@ -583,16 +610,16 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getSubsampling2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
private MultiLayerNetwork getSubsampling2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new SubsamplingLayer.Builder()
|
||||
return getNetWithLayer(dataType,new SubsamplingLayer.Builder()
|
||||
.kernelSize(2, 2)
|
||||
.stride(1, 1)
|
||||
.dataFormat(format)
|
||||
.helperAllowFallback(false)
|
||||
.build(), format, cm, null);
|
||||
} else {
|
||||
return getNetWithLayer(new SubsamplingLayer.Builder()
|
||||
return getNetWithLayer(dataType,new SubsamplingLayer.Builder()
|
||||
.kernelSize(2, 2)
|
||||
.stride(1, 1)
|
||||
.helperAllowFallback(false)
|
||||
|
@ -600,9 +627,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getSeparableConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
private MultiLayerNetwork getSeparableConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new SeparableConvolution2D.Builder()
|
||||
return getNetWithLayer(dataType,new SeparableConvolution2D.Builder()
|
||||
.kernelSize(3, 3)
|
||||
.stride(2, 2)
|
||||
.activation(Activation.TANH)
|
||||
|
@ -611,7 +638,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
.helperAllowFallback(false)
|
||||
.build(), format, cm, null);
|
||||
} else {
|
||||
return getNetWithLayer(new SeparableConvolution2D.Builder()
|
||||
return getNetWithLayer(dataType,new SeparableConvolution2D.Builder()
|
||||
.kernelSize(3, 3)
|
||||
.stride(2, 2)
|
||||
.activation(Activation.TANH)
|
||||
|
@ -621,9 +648,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getDepthwiseConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
private MultiLayerNetwork getDepthwiseConv2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
|
||||
return getNetWithLayer(dataType,new DepthwiseConvolution2D.Builder()
|
||||
.depthMultiplier(2)
|
||||
.kernelSize(3, 3)
|
||||
.stride(2, 2)
|
||||
|
@ -633,7 +660,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
.helperAllowFallback(false)
|
||||
.build(), format, cm, null);
|
||||
} else {
|
||||
return getNetWithLayer(new DepthwiseConvolution2D.Builder()
|
||||
return getNetWithLayer(dataType,new DepthwiseConvolution2D.Builder()
|
||||
.depthMultiplier(2)
|
||||
.kernelSize(3, 3)
|
||||
.stride(2, 2)
|
||||
|
@ -644,59 +671,59 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getLrnLayer(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
private MultiLayerNetwork getLrnLayer(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new LocalResponseNormalization.Builder()
|
||||
return getNetWithLayer(dataType,new LocalResponseNormalization.Builder()
|
||||
.dataFormat(format)
|
||||
.helperAllowFallback(false)
|
||||
.build(), format, cm, null);
|
||||
} else {
|
||||
return getNetWithLayer(new LocalResponseNormalization.Builder()
|
||||
return getNetWithLayer(dataType,new LocalResponseNormalization.Builder()
|
||||
.helperAllowFallback(false)
|
||||
.build(), format, cm, null);
|
||||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getZeroPaddingNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
private MultiLayerNetwork getZeroPaddingNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2)
|
||||
return getNetWithLayer(dataType,new ZeroPaddingLayer.Builder(2,2)
|
||||
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
|
||||
} else {
|
||||
return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2).build(),
|
||||
return getNetWithLayer(dataType,new ZeroPaddingLayer.Builder(2,2).build(),
|
||||
format, ConvolutionMode.Same, null);
|
||||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getCropping2dNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
private MultiLayerNetwork getCropping2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new Cropping2D.Builder(2,2)
|
||||
return getNetWithLayer(dataType,new Cropping2D.Builder(2,2)
|
||||
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
|
||||
} else {
|
||||
return getNetWithLayer(new Cropping2D.Builder(2,2)
|
||||
return getNetWithLayer(dataType,new Cropping2D.Builder(2,2)
|
||||
.build(), format, ConvolutionMode.Same, null);
|
||||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getUpsamplingNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
private MultiLayerNetwork getUpsamplingNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new Upsampling2D.Builder(2)
|
||||
return getNetWithLayer(dataType,new Upsampling2D.Builder(2)
|
||||
.dataFormat(format).build(), format, ConvolutionMode.Same, null);
|
||||
} else {
|
||||
return getNetWithLayer(new Upsampling2D.Builder(2)
|
||||
return getNetWithLayer(dataType,new Upsampling2D.Builder(2)
|
||||
.build(), format, ConvolutionMode.Same, null);
|
||||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getDeconv2DNet2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
private MultiLayerNetwork getDeconv2DNet2dNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
|
||||
return getNetWithLayer(dataType,new Deconvolution2D.Builder().nOut(2)
|
||||
.activation(Activation.TANH)
|
||||
.kernelSize(2,2)
|
||||
.dataFormat(format)
|
||||
.stride(2,2)
|
||||
.build(), format, cm, null);
|
||||
} else {
|
||||
return getNetWithLayer(new Deconvolution2D.Builder().nOut(2)
|
||||
return getNetWithLayer(dataType,new Deconvolution2D.Builder().nOut(2)
|
||||
.activation(Activation.TANH)
|
||||
.kernelSize(2,2)
|
||||
.dataFormat(format)
|
||||
|
@ -705,50 +732,50 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getBatchNormNet(boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
private MultiLayerNetwork getBatchNormNet(DataType dataType,boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new BatchNormalization.Builder()
|
||||
return getNetWithLayer(dataType,new BatchNormalization.Builder()
|
||||
.useLogStd(logStdev)
|
||||
.dataFormat(format)
|
||||
.helperAllowFallback(false)
|
||||
.nOut(3).build(), format, ConvolutionMode.Same, null);
|
||||
} else {
|
||||
return getNetWithLayer(new BatchNormalization.Builder()
|
||||
return getNetWithLayer(dataType,new BatchNormalization.Builder()
|
||||
.useLogStd(logStdev)
|
||||
.helperAllowFallback(false)
|
||||
.nOut(3).build(), format, ConvolutionMode.Same, null);
|
||||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
private MultiLayerNetwork getSpaceToDepthNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new SpaceToDepthLayer.Builder()
|
||||
return getNetWithLayer(dataType,new SpaceToDepthLayer.Builder()
|
||||
.blocks(2)
|
||||
.dataFormat(format)
|
||||
.build(), format, ConvolutionMode.Same, null);
|
||||
} else {
|
||||
return getNetWithLayer(new SpaceToDepthLayer.Builder()
|
||||
return getNetWithLayer(dataType,new SpaceToDepthLayer.Builder()
|
||||
.blocks(2)
|
||||
.build(), format, ConvolutionMode.Same, null);
|
||||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
private MultiLayerNetwork getSpaceToBatchNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new SpaceToBatchLayer.Builder()
|
||||
return getNetWithLayer(dataType,new SpaceToBatchLayer.Builder()
|
||||
.blocks(2, 2)
|
||||
.dataFormat(format)
|
||||
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
|
||||
} else {
|
||||
return getNetWithLayer(new SpaceToBatchLayer.Builder()
|
||||
return getNetWithLayer(dataType,new SpaceToBatchLayer.Builder()
|
||||
.blocks(2, 2)
|
||||
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
|
||||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getLocallyConnectedNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
private MultiLayerNetwork getLocallyConnectedNet(DataType dataType,CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new LocallyConnected2D.Builder()
|
||||
return getNetWithLayer(dataType,new LocallyConnected2D.Builder()
|
||||
.kernelSize(3, 3)
|
||||
.stride(2, 2)
|
||||
.activation(Activation.TANH)
|
||||
|
@ -756,7 +783,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
.nOut(3)
|
||||
.build(), format, cm, null);
|
||||
} else {
|
||||
return getNetWithLayer(new LocallyConnected2D.Builder()
|
||||
return getNetWithLayer(dataType,new LocallyConnected2D.Builder()
|
||||
.kernelSize(3, 3)
|
||||
.stride(2, 2)
|
||||
.activation(Activation.TANH)
|
||||
|
@ -765,9 +792,9 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) {
|
||||
private MultiLayerNetwork getNetWithLayer(DataType dataType,Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) {
|
||||
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
|
||||
.dataType(this.dataType)
|
||||
.dataType(dataType)
|
||||
.seed(12345)
|
||||
.convolutionMode(cm)
|
||||
.list()
|
||||
|
@ -794,13 +821,13 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
return net;
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
|
||||
private MultiLayerNetwork getGlobalPoolingNet(DataType dataType,CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
|
||||
return getNetWithLayer(dataType,new GlobalPoolingLayer.Builder(pt)
|
||||
.poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2})
|
||||
.build(), format, ConvolutionMode.Same, null);
|
||||
} else {
|
||||
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
|
||||
return getNetWithLayer(dataType,new GlobalPoolingLayer.Builder(pt)
|
||||
.build(), format, ConvolutionMode.Same, null);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,8 +45,11 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
|||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.deeplearning4j.util.TimeSeriesUtils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.enums.RnnDataFormat;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -61,30 +64,29 @@ import org.deeplearning4j.nn.workspace.ArrayType;
|
|||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.util.Arrays;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.deeplearning4j.nn.conf.RNNFormat.NCW;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
@DisplayName("Bidirectional Test")
|
||||
class BidirectionalTest extends BaseDL4JTest {
|
||||
|
||||
private RNNFormat rnnDataFormat;
|
||||
|
||||
public BidirectionalTest(RNNFormat rnnDataFormat) {
|
||||
this.rnnDataFormat = rnnDataFormat;
|
||||
}
|
||||
|
||||
@Parameterized.Parameters
|
||||
public static Object[] params() {
|
||||
return RNNFormat.values();
|
||||
public static Stream<Arguments> params() {
|
||||
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Compare Implementations")
|
||||
void compareImplementations() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void compareImplementations(RNNFormat rnnDataFormat) {
|
||||
for (WorkspaceMode wsm : WorkspaceMode.values()) {
|
||||
log.info("*** Starting workspace mode: " + wsm);
|
||||
// Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params
|
||||
|
@ -147,9 +149,11 @@ class BidirectionalTest extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Compare Implementations Comp Graph")
|
||||
void compareImplementationsCompGraph() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void compareImplementationsCompGraph(RNNFormat rnnFormat) {
|
||||
// for(WorkspaceMode wsm : WorkspaceMode.values()) {
|
||||
for (WorkspaceMode wsm : new WorkspaceMode[] { WorkspaceMode.NONE, WorkspaceMode.ENABLED }) {
|
||||
log.info("*** Starting workspace mode: " + wsm);
|
||||
|
@ -187,8 +191,8 @@ class BidirectionalTest extends BaseDL4JTest {
|
|||
Gradient g2 = net2.gradient();
|
||||
assertEquals(g1.gradient(), g2.gradient());
|
||||
// Ensure updates are equal:
|
||||
ComputationGraphUpdater u1 = (ComputationGraphUpdater) net1.getUpdater();
|
||||
ComputationGraphUpdater u2 = (ComputationGraphUpdater) net2.getUpdater();
|
||||
ComputationGraphUpdater u1 = net1.getUpdater();
|
||||
ComputationGraphUpdater u2 = net2.getUpdater();
|
||||
assertEquals(u1.getUpdaterStateViewArray(), u2.getUpdaterStateViewArray());
|
||||
u1.update(g1, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces());
|
||||
u2.update(g2, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces());
|
||||
|
@ -205,7 +209,9 @@ class BidirectionalTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Serialization")
|
||||
void testSerialization() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testSerialization(RNNFormat rnnDataFormat) throws Exception {
|
||||
for (WorkspaceMode wsm : WorkspaceMode.values()) {
|
||||
log.info("*** Starting workspace mode: " + wsm);
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -242,7 +248,9 @@ class BidirectionalTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Serialization Comp Graph")
|
||||
void testSerializationCompGraph() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testSerializationCompGraph(RNNFormat rnnDataFormat) throws Exception {
|
||||
for (WorkspaceMode wsm : WorkspaceMode.values()) {
|
||||
log.info("*** Starting workspace mode: " + wsm);
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -277,7 +285,9 @@ class BidirectionalTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Simple Bidirectional")
|
||||
void testSimpleBidirectional() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testSimpleBidirectional(RNNFormat rnnDataFormat) {
|
||||
for (WorkspaceMode wsm : WorkspaceMode.values()) {
|
||||
log.info("*** Starting workspace mode: " + wsm);
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -362,7 +372,9 @@ class BidirectionalTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Simple Bidirectional Comp Graph")
|
||||
void testSimpleBidirectionalCompGraph() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testSimpleBidirectionalCompGraph(RNNFormat rnnDataFormat) {
|
||||
for (WorkspaceMode wsm : WorkspaceMode.values()) {
|
||||
log.info("*** Starting workspace mode: " + wsm);
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
*/
|
||||
package org.deeplearning4j.nn.layers.recurrent;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
|
@ -34,9 +33,12 @@ import org.deeplearning4j.nn.params.GravesBidirectionalLSTMParamInitializer;
|
|||
import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -44,31 +46,29 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.learning.config.AdaGrad;
|
||||
import org.nd4j.linalg.learning.config.NoOp;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
import java.util.Arrays;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@DisplayName("Graves Bidirectional LSTM Test")
|
||||
class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
||||
|
||||
private double score = 0.0;
|
||||
|
||||
private RNNFormat rnnDataFormat;
|
||||
|
||||
public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat) {
|
||||
this.rnnDataFormat = rnnDataFormat;
|
||||
|
||||
public static Stream<Arguments> params(){
|
||||
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
||||
}
|
||||
|
||||
@Parameterized.Parameters
|
||||
public static Object[] params() {
|
||||
return RNNFormat.values();
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Test Bidirectional LSTM Graves Forward Basic")
|
||||
void testBidirectionalLSTMGravesForwardBasic() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
void testBidirectionalLSTMGravesForwardBasic(RNNFormat rnnDataFormat) {
|
||||
// Very basic test of forward prop. of LSTM layer with a time series.
|
||||
// Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape.
|
||||
int nIn = 13;
|
||||
|
@ -110,19 +110,21 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Bidirectional LSTM Graves Backward Basic")
|
||||
void testBidirectionalLSTMGravesBackwardBasic() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
void testBidirectionalLSTMGravesBackwardBasic(RNNFormat rnnDataFormat) {
|
||||
// Very basic test of backprop for mini-batch + time series
|
||||
// Essentially make sure it doesn't throw any exceptions, and provides output in the correct shape.
|
||||
testGravesBackwardBasicHelper(13, 3, 17, 10, 7);
|
||||
testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 7);
|
||||
// Edge case: miniBatchSize = 1
|
||||
testGravesBackwardBasicHelper(13, 3, 17, 1, 7);
|
||||
testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 1, 7);
|
||||
// Edge case: timeSeriesLength = 1
|
||||
testGravesBackwardBasicHelper(13, 3, 17, 10, 1);
|
||||
testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 10, 1);
|
||||
// Edge case: both miniBatchSize = 1 and timeSeriesLength = 1
|
||||
testGravesBackwardBasicHelper(13, 3, 17, 1, 1);
|
||||
testGravesBackwardBasicHelper(rnnDataFormat,13, 3, 17, 1, 1);
|
||||
}
|
||||
|
||||
private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) {
|
||||
private void testGravesBackwardBasicHelper(RNNFormat rnnDataFormat,int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) {
|
||||
INDArray inputData = (rnnDataFormat == RNNFormat.NCW) ? Nd4j.ones(miniBatchSize, nIn, timeSeriesLength) : Nd4j.ones(miniBatchSize, timeSeriesLength, nIn);
|
||||
NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat).dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()).build();
|
||||
long numParams = conf.getLayer().initializer().numParams(conf);
|
||||
|
@ -204,7 +206,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Get Set Parmas")
|
||||
void testGetSetParmas() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
void testGetSetParmas(RNNFormat rnnDataFormat) {
|
||||
final int nIn = 2;
|
||||
final int layerSize = 3;
|
||||
final int miniBatchSize = 2;
|
||||
|
@ -224,7 +228,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Simple Forwards And Backwards Activation")
|
||||
void testSimpleForwardsAndBackwardsActivation() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
void testSimpleForwardsAndBackwardsActivation(RNNFormat rnnDataFormat) {
|
||||
final int nIn = 2;
|
||||
final int layerSize = 3;
|
||||
final int miniBatchSize = 1;
|
||||
|
@ -342,7 +348,9 @@ class GravesBidirectionalLSTMTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("Test Gate Activation Fns Sanity Check")
|
||||
void testGateActivationFnsSanityCheck() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
void testGateActivationFnsSanityCheck(RNNFormat rnnDataFormat) {
|
||||
for (String gateAfn : new String[] { "sigmoid", "hardsigmoid" }) {
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat).build()).layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat).activation(Activation.TANH).build()).build();
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
|
|
|
@ -30,36 +30,35 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
@DisplayName("Mask Zero Layer Test")
|
||||
class MaskZeroLayerTest extends BaseDL4JTest {
|
||||
|
||||
private RNNFormat rnnDataFormat;
|
||||
|
||||
public MaskZeroLayerTest(RNNFormat rnnDataFormat) {
|
||||
this.rnnDataFormat = rnnDataFormat;
|
||||
public static Stream<Arguments> params() {
|
||||
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
||||
}
|
||||
|
||||
@Parameterized.Parameters
|
||||
public static Object[] params() {
|
||||
return RNNFormat.values();
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Activate")
|
||||
void activate() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void activate(RNNFormat rnnDataFormat) {
|
||||
// GIVEN two examples where some of the timesteps are zero.
|
||||
INDArray ex1 = Nd4j.create(new double[][] { new double[] { 0, 3, 5 }, new double[] { 0, 0, 2 } });
|
||||
INDArray ex2 = Nd4j.create(new double[][] { new double[] { 0, 0, 2 }, new double[] { 0, 0, 2 } });
|
||||
|
@ -95,9 +94,12 @@ class MaskZeroLayerTest extends BaseDL4JTest {
|
|||
assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@DisplayName("Test Serialization")
|
||||
void testSerialization() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
void testSerialization(RNNFormat rnnDataFormat) {
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder().setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()).build();
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
|
|
@ -40,8 +40,10 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
|||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -51,30 +53,31 @@ import org.nd4j.common.primitives.Pair;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
@AllArgsConstructor
|
||||
public class RnnDataFormatTests extends BaseDL4JTest {
|
||||
|
||||
private boolean helpers;
|
||||
private boolean lastTimeStep;
|
||||
private boolean maskZeros;
|
||||
|
||||
@Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}")
|
||||
public static List params(){
|
||||
public static Stream<Arguments> params() {
|
||||
List<Object[]> ret = new ArrayList<>();
|
||||
for (boolean helpers: new boolean[]{true, false})
|
||||
for (boolean lastTimeStep: new boolean[]{true, false})
|
||||
for (boolean maskZero: new boolean[]{true, false})
|
||||
ret.add(new Object[]{helpers, lastTimeStep, maskZero});
|
||||
return ret;
|
||||
return ret.stream().map(Arguments::of);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testSimpleRnn() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testSimpleRnn(boolean helpers,
|
||||
boolean lastTimeStep,
|
||||
boolean maskZeros
|
||||
) {
|
||||
try {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -107,7 +110,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLSTM() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testLSTM(boolean helpers,
|
||||
boolean lastTimeStep,
|
||||
boolean maskZeros) {
|
||||
try {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -141,7 +148,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testGraveLSTM() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testGraveLSTM(boolean helpers,
|
||||
boolean lastTimeStep,
|
||||
boolean maskZeros) {
|
||||
try {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -175,7 +186,11 @@ public class RnnDataFormatTests extends BaseDL4JTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testGraveBiLSTM() {
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testGraveBiLSTM(boolean helpers,
|
||||
boolean lastTimeStep,
|
||||
boolean maskZeros) {
|
||||
try {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
|
|
@ -34,14 +34,20 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
|||
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.enums.RnnDataFormat;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.linalg.learning.config.AdaGrad;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.deeplearning4j.nn.api.OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
|
||||
import static org.deeplearning4j.nn.weights.WeightInit.XAVIER_UNIFORM;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
@ -50,20 +56,16 @@ import static org.nd4j.linalg.activations.Activation.TANH;
|
|||
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;
|
||||
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestLastTimeStepLayer extends BaseDL4JTest {
|
||||
private RNNFormat rnnDataFormat;
|
||||
|
||||
public TestLastTimeStepLayer(RNNFormat rnnDataFormat){
|
||||
this.rnnDataFormat = rnnDataFormat;
|
||||
}
|
||||
@Parameterized.Parameters(name="{0}")
|
||||
public static Object[] params(){
|
||||
return RNNFormat.values();
|
||||
public static Stream<Arguments> params(){
|
||||
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLastTimeStepVertex() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testLastTimeStepVertex(RNNFormat rnnDataFormat) {
|
||||
|
||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
|
||||
.addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder()
|
||||
|
@ -126,7 +128,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMaskingAndAllMasked(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testMaskingAndAllMasked(RNNFormat rnnDataFormat) {
|
||||
ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
|
||||
.optimizationAlgo(STOCHASTIC_GRADIENT_DESCENT)
|
||||
.weightInit(XAVIER_UNIFORM)
|
||||
|
|
|
@ -36,8 +36,11 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.enums.RnnDataFormat;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -49,25 +52,23 @@ import org.nd4j.common.primitives.Pair;
|
|||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestRnnLayers extends BaseDL4JTest {
|
||||
|
||||
private RNNFormat rnnDataFormat;
|
||||
|
||||
public TestRnnLayers(RNNFormat rnnDataFormat){
|
||||
this.rnnDataFormat = rnnDataFormat;
|
||||
}
|
||||
@Parameterized.Parameters
|
||||
public static Object[] params(){
|
||||
return RNNFormat.values();
|
||||
public static Stream<Arguments> params(){
|
||||
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTimeStepIs3Dimensional() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testTimeStepIs3Dimensional(RNNFormat rnnDataFormat) {
|
||||
|
||||
int nIn = 12;
|
||||
int nOut = 3;
|
||||
|
@ -117,7 +118,9 @@ public class TestRnnLayers extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDropoutRecurrentLayers(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testDropoutRecurrentLayers(RNNFormat rnnDataFormat){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
String[] layerTypes = new String[]{"graves", "lstm", "simple"};
|
||||
|
@ -215,9 +218,11 @@ public class TestRnnLayers extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMismatchedInputLabelLength(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testMismatchedInputLabelLength(RNNFormat rnnDataFormat){
|
||||
|
||||
for( int i=0; i<2; i++ ){
|
||||
for( int i = 0; i < 2; i++) {
|
||||
|
||||
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()
|
||||
|
||||
|
|
|
@ -29,8 +29,10 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
|||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -38,25 +40,25 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.learning.config.NoOp;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
|
||||
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
||||
import static org.nd4j.linalg.indexing.NDArrayIndex.point;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestSimpleRnn extends BaseDL4JTest {
|
||||
|
||||
private RNNFormat rnnDataFormat;
|
||||
|
||||
public TestSimpleRnn(RNNFormat rnnDataFormat){
|
||||
this.rnnDataFormat = rnnDataFormat;
|
||||
}
|
||||
@Parameterized.Parameters
|
||||
public static Object[] params(){
|
||||
return RNNFormat.values();
|
||||
public static Stream<Arguments> params() {
|
||||
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSimpleRnn(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testSimpleRnn(RNNFormat rnnDataFormat) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
int m = 3;
|
||||
|
@ -125,7 +127,9 @@ public class TestSimpleRnn extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBiasInit(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testBiasInit(RNNFormat rnnDataFormat) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nIn = 5;
|
||||
int layerSize = 6;
|
||||
|
|
|
@ -37,8 +37,10 @@ import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
|
|||
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -47,22 +49,22 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.learning.config.Adam;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestTimeDistributed extends BaseDL4JTest {
|
||||
|
||||
private RNNFormat rnnDataFormat;
|
||||
|
||||
public TestTimeDistributed(RNNFormat rnnDataFormat){
|
||||
this.rnnDataFormat = rnnDataFormat;
|
||||
}
|
||||
@Parameterized.Parameters
|
||||
public static Object[] params(){
|
||||
return RNNFormat.values();
|
||||
public static Stream<Arguments> params(){
|
||||
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTimeDistributed(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("#params")
|
||||
public void testTimeDistributed(RNNFormat rnnDataFormat){
|
||||
for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
|
||||
|
||||
MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
|
||||
|
@ -133,10 +135,12 @@ public class TestTimeDistributed extends BaseDL4JTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testTimeDistributedDense(){
|
||||
@MethodSource("#params")
|
||||
@ParameterizedTest
|
||||
public void testTimeDistributedDense(RNNFormat rnnDataFormat){
|
||||
|
||||
for( int rnnType=0; rnnType<3; rnnType++ ) {
|
||||
for( int ffType=0; ffType<3; ffType++ ) {
|
||||
for( int rnnType = 0; rnnType < 3; rnnType++ ) {
|
||||
for( int ffType = 0; ffType < 3; ffType++ ) {
|
||||
|
||||
Layer l0, l2;
|
||||
switch (rnnType) {
|
||||
|
|
|
@ -39,8 +39,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
|
|
@ -145,92 +145,6 @@
|
|||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<groupId>org.jetbrains.kotlin</groupId>
|
||||
<artifactId>kotlin-maven-plugin</artifactId>
|
||||
<version>1.4.30-M1</version>
|
||||
<configuration>
|
||||
<args>
|
||||
<arg>-Xjsr305=strict</arg>
|
||||
</args>
|
||||
<compilerPlugins>
|
||||
<plugin>spring</plugin>
|
||||
<plugin>jpa</plugin>
|
||||
</compilerPlugins>
|
||||
</configuration>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.jetbrains.kotlin</groupId>
|
||||
<artifactId>kotlin-maven-allopen</artifactId>
|
||||
<version>${kotlin.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.jetbrains.kotlin</groupId>
|
||||
<artifactId>kotlin-maven-noarg</artifactId>
|
||||
<version>${kotlin.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>compile</id>
|
||||
<goals> <goal>compile</goal> </goals>
|
||||
<configuration>
|
||||
<sourceDirs>
|
||||
<sourceDir>${project.basedir}/src/main/stubs</sourceDir>
|
||||
<sourceDir>${project.basedir}/src/main/kotlin</sourceDir>
|
||||
<sourceDir>${project.basedir}/src/main/java</sourceDir>
|
||||
<sourceDir>${project.basedir}/src/main/ops</sourceDir>
|
||||
</sourceDirs>
|
||||
</configuration>
|
||||
</execution>
|
||||
<execution>
|
||||
<id>test-compile</id>
|
||||
<goals> <goal>test-compile</goal> </goals>
|
||||
<configuration>
|
||||
<sourceDirs>
|
||||
<sourceDir>${project.basedir}/src/test/stubs</sourceDir>
|
||||
<sourceDir>${project.basedir}/src/test/kotlin</sourceDir>
|
||||
<sourceDir>${project.basedir}/src/test/java</sourceDir>
|
||||
<sourceDir>${project.basedir}/src/test/ops</sourceDir>
|
||||
</sourceDirs>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
<!-- https://kotlinlang.org/docs/reference/using-maven.html -->
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
<version>3.5.1</version>
|
||||
<executions>
|
||||
<!-- Replacing default-compile as it is treated specially by maven -->
|
||||
<execution>
|
||||
<id>default-compile</id>
|
||||
<phase>none</phase>
|
||||
</execution>
|
||||
<!-- Replacing default-testCompile as it is treated specially by maven -->
|
||||
<execution>
|
||||
<id>default-testCompile</id>
|
||||
<phase>none</phase>
|
||||
</execution>
|
||||
<execution>
|
||||
<id>java-compile</id>
|
||||
<phase>compile</phase>
|
||||
<goals> <goal>compile</goal> </goals>
|
||||
</execution>
|
||||
<execution>
|
||||
<id>java-test-compile</id>
|
||||
<phase>test-compile</phase>
|
||||
<goals> <goal>testCompile</goal> </goals>
|
||||
</execution>
|
||||
</executions>
|
||||
<configuration>
|
||||
<source>${java.version}</source>
|
||||
<target>${java.version}</target>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
|
||||
|
@ -244,7 +158,10 @@
|
|||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-params</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.jetbrains.kotlin</groupId>
|
||||
<artifactId>kotlin-stdlib-jdk8</artifactId>
|
||||
|
@ -261,11 +178,14 @@
|
|||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>samediff-import-tensorflow</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>samediff-import-onnx</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>compile</scope>
|
||||
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
|
|
@ -22,11 +22,6 @@ package org.nd4j;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.common.tests.AbstractAssertTestsClass;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.imports.tfgraphs.TFGraphTestAllLibnd4j;
|
||||
import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
|
||||
import org.nd4j.imports.tfgraphs.TFGraphTestList;
|
||||
import org.nd4j.imports.tfgraphs.TFGraphTestZooModels;
|
||||
import org.nd4j.imports.listeners.ImportModelDebugger;
|
||||
import java.util.*;
|
||||
|
||||
@Slf4j
|
||||
|
@ -36,11 +31,6 @@ public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
|
|||
protected Set<Class<?>> getExclusions() {
|
||||
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
|
||||
return new HashSet<>(Arrays.asList(
|
||||
TFGraphTestAllSameDiff.class,
|
||||
TFGraphTestAllLibnd4j.class,
|
||||
TFGraphTestList.class,
|
||||
TFGraphTestZooModels.class,
|
||||
ImportModelDebugger.class //Run manually only, otherwise ignored
|
||||
));
|
||||
}
|
||||
|
||||
|
|
|
@ -20,19 +20,16 @@
|
|||
|
||||
package org.nd4j;
|
||||
|
||||
import org.bytedeco.javacpp.Loader;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.nd4j.autodiff.opvalidation.*;
|
||||
import org.nd4j.autodiff.validation.OpValidation;
|
||||
import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
|
||||
//import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.function.Function;
|
||||
|
||||
import static org.junit.Assume.assumeFalse;
|
||||
|
||||
|
@ -49,7 +46,7 @@ import static org.junit.Assume.assumeFalse;
|
|||
TransformOpValidation.class,
|
||||
|
||||
//TF import tests
|
||||
TFGraphTestAllSameDiff.class
|
||||
//TFGraphTestAllSameDiff.class
|
||||
//TFGraphTestAllLibnd4j.class
|
||||
})
|
||||
//IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test"
|
||||
|
|
|
@ -27,10 +27,12 @@ import org.apache.commons.io.FileUtils;
|
|||
import org.apache.commons.io.FilenameUtils;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.imports.converters.ImportClassMapping;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.NoOp;
|
||||
import org.nd4j.linalg.api.ops.compat.CompatSparseToDense;
|
||||
|
@ -122,13 +124,11 @@ import java.util.regex.Matcher;
|
|||
import java.util.regex.Pattern;
|
||||
|
||||
@Disabled("No longer relevant after model import rewrite.")
|
||||
public class TestOpMapping extends BaseNd4jTest {
|
||||
public class TestOpMapping extends BaseNd4jTestWithBackends {
|
||||
|
||||
Set<Class<? extends DifferentialFunction>> subTypes;
|
||||
|
||||
public TestOpMapping(Nd4jBackend b){
|
||||
super(b);
|
||||
|
||||
public TestOpMapping() {
|
||||
Reflections reflections = new Reflections("org.nd4j");
|
||||
subTypes = reflections.getSubTypesOf(DifferentialFunction.class);
|
||||
}
|
||||
|
@ -146,6 +146,8 @@ public class TestOpMapping extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testOpMappingCoverage() throws Exception {
|
||||
Map<String, DifferentialFunction> opNameMapping = ImportClassMapping.getOpNameMapping();
|
||||
Map<String, DifferentialFunction> tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions();
|
||||
|
@ -196,7 +198,9 @@ public class TestOpMapping extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testOpsInNamespace() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testOpsInNamespace(Nd4jBackend backend) throws Exception {
|
||||
//Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't
|
||||
// want to add to a namespace for some reason)
|
||||
//Note that we ignore "*Bp", "*Gradient", "*Derivative" etc ops
|
||||
|
@ -354,8 +358,11 @@ public class TestOpMapping extends BaseNd4jTest {
|
|||
s.add(Assign.class);
|
||||
}
|
||||
|
||||
@Test @Disabled
|
||||
public void generateOpClassList() throws Exception{
|
||||
@Test
|
||||
@Disabled
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void generateOpClassList(Nd4jBackend backend) throws Exception{
|
||||
Reflections reflections = new Reflections("org.nd4j");
|
||||
Set<Class<? extends DifferentialFunction>> subTypes = reflections.getSubTypesOf(DifferentialFunction.class);
|
||||
|
||||
|
@ -366,12 +373,7 @@ public class TestOpMapping extends BaseNd4jTest {
|
|||
l.add(c);
|
||||
}
|
||||
|
||||
Collections.sort(l, new Comparator<Class<?>>() {
|
||||
@Override
|
||||
public int compare(Class<?> o1, Class<?> o2) {
|
||||
return o1.getName().compareTo(o2.getName());
|
||||
}
|
||||
});
|
||||
Collections.sort(l, Comparator.comparing(Class::getName));
|
||||
|
||||
for(Class<?> c : l){
|
||||
System.out.println(c.getName() + ".class,");
|
||||
|
|
|
@ -22,6 +22,8 @@ package org.nd4j.autodiff;
|
|||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -31,7 +33,7 @@ import org.nd4j.autodiff.samediff.internal.FrameIter;
|
|||
import org.nd4j.autodiff.samediff.internal.InferenceSession;
|
||||
import org.nd4j.autodiff.samediff.internal.memory.NoOpMemoryMgr;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -46,19 +48,17 @@ import java.util.Map;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class TestSessions extends BaseNd4jTest {
|
||||
|
||||
public TestSessions(Nd4jBackend b){
|
||||
super(b);
|
||||
}
|
||||
public class TestSessions extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInferenceSessionBasic(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testInferenceSessionBasic(Nd4jBackend backend) {
|
||||
//So far: trivial test to check execution order
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -90,7 +90,9 @@ public class TestSessions extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testInferenceSessionBasic2(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testInferenceSessionBasic2(Nd4jBackend backend) {
|
||||
//So far: trivial test to check execution order
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -126,7 +128,9 @@ public class TestSessions extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMergeSimple(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMergeSimple(Nd4jBackend backend) {
|
||||
//This isn't really a sensible graph, as merge op behaviour is undefined when multiple inputs are available...
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -162,7 +166,9 @@ public class TestSessions extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testSwitchSimple(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSwitchSimple(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 3,3);
|
||||
|
|
|
@ -21,10 +21,12 @@
|
|||
package org.nd4j.autodiff.internal;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.samediff.internal.DependencyList;
|
||||
import org.nd4j.autodiff.samediff.internal.DependencyTracker;
|
||||
import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
@ -35,19 +37,18 @@ import java.util.Collections;
|
|||
import static junit.framework.TestCase.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class TestDependencyTracker extends BaseNd4jTest {
|
||||
public class TestDependencyTracker extends BaseNd4jTestWithBackends {
|
||||
|
||||
public TestDependencyTracker(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSimple(){
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSimple(Nd4jBackend backend){
|
||||
|
||||
DependencyTracker<String,String> dt = new DependencyTracker<>();
|
||||
|
||||
|
@ -93,8 +94,10 @@ public class TestDependencyTracker extends BaseNd4jTest {
|
|||
assertTrue(dt.isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSatisfiedBeforeAdd(){
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSatisfiedBeforeAdd(Nd4jBackend backend){
|
||||
DependencyTracker<String,String> dt = new DependencyTracker<>();
|
||||
|
||||
//Check different order of adding dependencies: i.e., mark X as satisfied, then add x -> y dependency
|
||||
|
@ -132,8 +135,10 @@ public class TestDependencyTracker extends BaseNd4jTest {
|
|||
assertFalse(dt.hasNewAllSatisfied());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMarkUnsatisfied(){
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMarkUnsatisfied(Nd4jBackend backend){
|
||||
|
||||
DependencyTracker<String,String> dt = new DependencyTracker<>();
|
||||
dt.addDependency("y", "x");
|
||||
|
@ -164,7 +169,9 @@ public class TestDependencyTracker extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIdentityDependencyTracker(){
|
||||
IdentityDependencyTracker<INDArray, String> dt = new IdentityDependencyTracker<>();
|
||||
assertTrue(dt.isEmpty());
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
package org.nd4j.autodiff.opvalidation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.validation.GradCheckUtil;
|
||||
|
@ -38,12 +40,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
|||
|
||||
public class ActivationGradChecks extends BaseOpValidation {
|
||||
|
||||
public ActivationGradChecks(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testActivationGradientCheck1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testActivationGradientCheck1(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4));
|
||||
|
@ -61,7 +62,9 @@ public class ActivationGradChecks extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testActivationGradientCheck2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testActivationGradientCheck2(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable x = sd.placeHolder("x", DataType.DOUBLE, 3, 4);
|
||||
|
|
|
@ -21,18 +21,14 @@
|
|||
package org.nd4j.autodiff.opvalidation;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
public abstract class BaseOpValidation extends BaseNd4jTest {
|
||||
public abstract class BaseOpValidation extends BaseNd4jTestWithBackends {
|
||||
|
||||
private DataType initialType;
|
||||
private DataType initialType = Nd4j.dataType();
|
||||
|
||||
public BaseOpValidation(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
|
|
@ -27,6 +27,8 @@ import java.util.List;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.validation.OpValidation;
|
||||
|
@ -65,9 +67,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
|
||||
@Slf4j
|
||||
public class LayerOpValidation extends BaseOpValidation {
|
||||
public LayerOpValidation(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
|
@ -75,7 +74,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testXwPlusB() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testXwPlusB(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
@ -109,7 +110,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReluLayer() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReluLayer(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
@ -137,7 +140,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBiasAdd() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBiasAdd(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
@ -161,7 +166,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testConv2d() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConv2d(Nd4jBackend backend) {
|
||||
//avg pool, batch norm, conv2d, max pool 2d, pooling2d, upsampling
|
||||
//Tested elsewhere: deconv2d, depthwise2d, LRN, sconv2d
|
||||
|
||||
|
@ -301,7 +308,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLrn2d() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLrn2d(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}};
|
||||
|
@ -342,7 +351,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testIm2Col() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIm2Col(Nd4jBackend backend) {
|
||||
//OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -381,7 +392,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testOutputShape() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testOutputShape(Nd4jBackend backend) {
|
||||
long[] inSize = {1, 8, 8, 3};
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -431,7 +444,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testAvgPool() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAvgPool(Nd4jBackend backend) {
|
||||
long[] inSize = {1, 8, 8, 3}; //NHWC
|
||||
|
||||
Pooling2DConfig conf = Pooling2DConfig.builder()
|
||||
|
@ -474,7 +489,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testConv3d() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConv3d(Nd4jBackend backend) {
|
||||
//Pooling3d, Conv3D, batch norm
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -576,7 +593,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testDepthWiseConv2dBasic() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDepthWiseConv2dBasic(Nd4jBackend backend) {
|
||||
int nIn = 3;
|
||||
int depthWise = 4;
|
||||
int kH = 2;
|
||||
|
@ -615,7 +634,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSeparableConv2dBasic() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSeparableConv2dBasic(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nIn = 2;
|
||||
int nOut = 3;
|
||||
|
@ -671,7 +692,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDeconv2dBasic() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDeconv2dBasic(Nd4jBackend backend) {
|
||||
int nIn = 2;
|
||||
int nOut = 3;
|
||||
int kH = 2;
|
||||
|
@ -715,7 +738,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testConv2dBasic() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConv2dBasic(Nd4jBackend backend) {
|
||||
int nIn = 3;
|
||||
int nOut = 4;
|
||||
int kH = 2;
|
||||
|
@ -756,7 +781,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMaxPoolingArgMax() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMaxPoolingArgMax(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nIn = 3;
|
||||
int kH = 2;
|
||||
|
@ -785,7 +812,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMaxPooling2dBasic() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMaxPooling2dBasic(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nIn = 3;
|
||||
int kH = 2;
|
||||
|
@ -843,7 +872,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAvgPooling2dBasic() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAvgPooling2dBasic(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nIn = 3;
|
||||
int kH = 2;
|
||||
|
@ -892,7 +923,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAvgPooling3dBasic() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAvgPooling3dBasic(Nd4jBackend backend) {
|
||||
int nIn = 3;
|
||||
int kH = 2;
|
||||
int kW = 2;
|
||||
|
@ -929,7 +962,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMaxPooling3dBasic() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMaxPooling3dBasic(Nd4jBackend backend) {
|
||||
int nIn = 3;
|
||||
int kH = 2;
|
||||
int kW = 2;
|
||||
|
@ -967,7 +1002,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testConv1dBasic() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConv1dBasic(Nd4jBackend backend) {
|
||||
int nIn = 3;
|
||||
int nOut = 4;
|
||||
int k = 2;
|
||||
|
@ -1002,7 +1039,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testConv1dCausal() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConv1dCausal(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int nIn = 3;
|
||||
int nOut = 4;
|
||||
|
@ -1051,7 +1090,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testConv1dForward() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConv1dForward(Nd4jBackend backend) {
|
||||
int nIn = 2;
|
||||
int nOut = 1;
|
||||
int kernel = 3;
|
||||
|
@ -1094,7 +1135,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testConv3dBasic() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConv3dBasic(Nd4jBackend backend) {
|
||||
int nIn = 3;
|
||||
int nOut = 4;
|
||||
int kH = 2;
|
||||
|
@ -1140,7 +1183,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDeConv3dBasic() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDeConv3dBasic(Nd4jBackend backend) {
|
||||
int nIn = 4;
|
||||
int nOut = 3;
|
||||
int kH = 2;
|
||||
|
@ -1185,7 +1230,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLayerNorm() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLayerNorm(Nd4jBackend backend) {
|
||||
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
||||
final INDArray standardized = random.ulike();
|
||||
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
|
||||
|
@ -1210,7 +1257,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLayerNorm4d() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLayerNorm4d(Nd4jBackend backend) {
|
||||
int mb = 3;
|
||||
int ch = 4;
|
||||
for (boolean nchw : new boolean[]{true, false}) {
|
||||
|
@ -1242,7 +1291,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testLayerNormOP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLayerNormOP(Nd4jBackend backend) {
|
||||
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
||||
final INDArray standardized = random.ulike();
|
||||
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
|
||||
|
@ -1258,7 +1309,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLayerNormNoBias() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLayerNormNoBias(Nd4jBackend backend) {
|
||||
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
||||
final INDArray standardized = random.ulike();
|
||||
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
|
||||
|
@ -1281,7 +1334,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLayerNormOPNoBias() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLayerNormOPNoBias(Nd4jBackend backend) {
|
||||
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
||||
final INDArray standardized = random.ulike();
|
||||
Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
|
||||
|
@ -1296,7 +1351,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLayerNormNoDeviation() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLayerNormNoDeviation(Nd4jBackend backend) {
|
||||
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
random.putScalar(1, i, 7);
|
||||
|
@ -1326,36 +1383,36 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test()
|
||||
public void exceptionThrown_WhenConv1DConfigInvalid() {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
int nIn = 3;
|
||||
int nOut = 4;
|
||||
int k = 2;
|
||||
int mb = 3;
|
||||
int img = 28;
|
||||
public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
int nIn = 3;
|
||||
int nOut = 4;
|
||||
int k = 2;
|
||||
int mb = 3;
|
||||
int img = 28;
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
INDArray wArr = Nd4j.create(k, nIn, nOut);
|
||||
INDArray inArr = Nd4j.create(mb, nIn, img);
|
||||
SameDiff sd = SameDiff.create();
|
||||
INDArray wArr = Nd4j.create(k, nIn, nOut);
|
||||
INDArray inArr = Nd4j.create(mb, nIn, img);
|
||||
|
||||
SDVariable in = sd.var("in", inArr);
|
||||
SDVariable w = sd.var("W", wArr);
|
||||
SDVariable in = sd.var("in", inArr);
|
||||
SDVariable w = sd.var("W", wArr);
|
||||
|
||||
SDVariable[] vars = new SDVariable[]{in, w};
|
||||
SDVariable[] vars = new SDVariable[]{in, w};
|
||||
|
||||
Conv1DConfig conv1DConfig = Conv1DConfig.builder()
|
||||
.k(k).p(-1).s(0)
|
||||
.paddingMode(PaddingMode.VALID)
|
||||
.build();
|
||||
Conv1DConfig conv1DConfig = Conv1DConfig.builder()
|
||||
.k(k).p(-1).s(0)
|
||||
.paddingMode(PaddingMode.VALID)
|
||||
.build();
|
||||
|
||||
SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
|
||||
SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
|
||||
|
||||
});
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@Test()
|
||||
public void exceptionThrown_WhenConv2DConfigInvalid() {
|
||||
public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -1378,40 +1435,42 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test()
|
||||
public void exceptionThrown_WhenConf3DInvalid() {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
//NCDHW format
|
||||
int[] inSizeNCDHW = {2, 3, 4, 5, 5};
|
||||
//NCDHW format
|
||||
int[] inSizeNCDHW = {2, 3, 4, 5, 5};
|
||||
|
||||
List<String> failed = new ArrayList<>();
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
||||
for (boolean ncdhw : new boolean[]{true, false}) {
|
||||
int nIn = inSizeNCDHW[1];
|
||||
int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW));
|
||||
for (boolean ncdhw : new boolean[]{true, false}) {
|
||||
int nIn = inSizeNCDHW[1];
|
||||
int[] shape = (ncdhw ? inSizeNCDHW : ncdhwToNdhwc(inSizeNCDHW));
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.var("in", shape);
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.var("in", shape);
|
||||
|
||||
SDVariable out;
|
||||
String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape);
|
||||
SDVariable out;
|
||||
String msg = "0 - conv3d+bias+same, ncdhw=" + ncdhw + " - input " + Arrays.toString(shape);
|
||||
|
||||
SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC]
|
||||
SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10));
|
||||
out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder()
|
||||
.dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC)
|
||||
.isSameMode(true)
|
||||
.kH(2).kW(2).kD(2)
|
||||
.sD(1).sH(1).sW(-1).dW(-1)
|
||||
.build());
|
||||
}
|
||||
});
|
||||
SDVariable w0 = sd.var("w0", Nd4j.rand(new int[]{2, 2, 2, nIn, 3}).muli(10)); //[kD, kH, kW, iC, oC]
|
||||
SDVariable b0 = sd.var("b0", Nd4j.rand(new long[]{3}).muli(10));
|
||||
out = sd.cnn().conv3d(in, w0, b0, Conv3DConfig.builder()
|
||||
.dataFormat(ncdhw ? Conv3DConfig.NCDHW : Conv3DConfig.NDHWC)
|
||||
.isSameMode(true)
|
||||
.kH(2).kW(2).kD(2)
|
||||
.sD(1).sH(1).sW(-1).dW(-1)
|
||||
.build());
|
||||
}
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLayerNormMixedOrders() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLayerNormMixedOrders(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');
|
||||
INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f');
|
||||
|
@ -1458,7 +1517,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBiasAdd_nchw_nhwc() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBiasAdd_nchw_nhwc(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
for (boolean nchw : new boolean[]{true, false}) {
|
||||
|
@ -1489,6 +1550,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDepthwiseConv2D(){
|
||||
|
||||
int bS = 10;
|
||||
|
@ -1527,7 +1590,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void LSTMLayerTestCase1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void LSTMLayerTestCase1(Nd4jBackend backend) {
|
||||
|
||||
int bS = 5;
|
||||
int nIn = 3;
|
||||
|
@ -1602,7 +1667,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void LSTMLayerTestCase2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void LSTMLayerTestCase2(Nd4jBackend backend) {
|
||||
int bS = 5;
|
||||
int nIn = 3;
|
||||
int numUnits = 7;
|
||||
|
@ -1660,7 +1727,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void LSTMLayerTestCase3() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void LSTMLayerTestCase3(Nd4jBackend backend) {
|
||||
int bS = 5;
|
||||
int nIn = 3;
|
||||
int numUnits = 7;
|
||||
|
@ -1721,7 +1790,9 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void GRUTestCase() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void GRUTestCase(Nd4jBackend backend) {
|
||||
int bS = 5;
|
||||
int nIn = 4;
|
||||
int nOut = 6;
|
||||
|
|
|
@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.OpValidationSuite;
|
||||
import org.nd4j.autodiff.loss.LossReduce;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -43,9 +45,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
|
||||
@Slf4j
|
||||
public class LossOpValidation extends BaseOpValidation {
|
||||
public LossOpValidation(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
|
@ -56,7 +56,9 @@ public class LossOpValidation extends BaseOpValidation {
|
|||
public static final Set<String> NO_BP_YET = new HashSet<>();
|
||||
|
||||
@Test
|
||||
public void testLoss2d() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLoss2d(Nd4jBackend backend) {
|
||||
final List<String> oneDimensionalOutputFns = Arrays.asList("cosine", "mpwse", "softmaxxent", "softmaxxent_smooth", "mpwse", "sparsesoftmax");
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -69,7 +71,7 @@ public class LossOpValidation extends BaseOpValidation {
|
|||
"absdiff", "cosine", "hinge", "huber", "log", "mse",
|
||||
"sigmoidxent", "sigmoidxent_smooth", "softmaxxent", "softmaxxent_smooth", "mpwse",
|
||||
"sparsesoftmax"
|
||||
}) {
|
||||
}) {
|
||||
|
||||
|
||||
for(String weights : new String[]{"none", "scalar", "perExample", "perOutput"}) {
|
||||
|
@ -368,6 +370,8 @@ public class LossOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCosineDistance(){
|
||||
INDArray arr = Nd4j.create(new double[][]{{-0.3, -0.2, -0.1}, {0, 0.1, 0.2}});
|
||||
INDArray label = Nd4j.create(new double[][]{{1.0, 2.0, 3.0}, {-1.0, 2.0, 1.0}});
|
||||
|
@ -386,6 +390,8 @@ public class LossOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testL2Loss(){
|
||||
|
||||
for( int rank=0; rank<=3; rank++ ){
|
||||
|
@ -428,7 +434,9 @@ public class LossOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNonZeroResult() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNonZeroResult(Nd4jBackend backend) {
|
||||
INDArray predictions = Nd4j.rand(DataType.DOUBLE, 10, 5);
|
||||
INDArray w = Nd4j.scalar(1.0);
|
||||
INDArray label = Nd4j.rand(DataType.DOUBLE, 10, 5);
|
||||
|
@ -486,6 +494,8 @@ public class LossOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void TestStdLossMixedDataType(){
|
||||
// Default Data Type in this test suite is Double.
|
||||
// This test used to throw an Exception that we have mixed data types.
|
||||
|
|
|
@ -23,6 +23,8 @@ package org.nd4j.autodiff.opvalidation;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.OpValidationSuite;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -78,13 +80,12 @@ import static org.junit.Assume.assumeNotNull;
|
|||
@Slf4j
|
||||
public class MiscOpValidation extends BaseOpValidation {
|
||||
|
||||
public MiscOpValidation(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testGradientAutoBroadcast1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGradientAutoBroadcast1(Nd4jBackend backend) {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -171,7 +172,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGradientAutoBroadcast2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGradientAutoBroadcast2(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
@ -260,7 +263,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGradientAutoBroadcast3() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGradientAutoBroadcast3(Nd4jBackend backend) {
|
||||
//These tests: output size > input sizes
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -368,7 +373,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testScatterOpGradients() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testScatterOpGradients(Nd4jBackend backend) {
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
||||
for (int i = 0; i < 7; i++) {
|
||||
|
@ -470,6 +477,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testScatterUpdate(){
|
||||
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3);
|
||||
INDArray updates = Nd4j.create(new float[][]{
|
||||
|
@ -491,7 +500,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGatherGradient() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGatherGradient(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
@ -542,6 +553,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTrace(){
|
||||
//TODO need to work out how to handle shape_op for scalars...
|
||||
//OpValidationSuite.ignoreFailing();
|
||||
|
@ -567,7 +580,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testTensorGradTensorMmul() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTensorGradTensorMmul(Nd4jBackend backend) {
|
||||
OpValidationSuite.ignoreFailing();
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -589,7 +604,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMulGradient() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMulGradient(Nd4jBackend backend) {
|
||||
INDArray arr1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
INDArray arr2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
|
||||
|
@ -654,22 +671,21 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testMmulGradientManual() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMmulGradientManual(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
Map<String, INDArray> inputs = new HashMap<>();
|
||||
inputs.put("x", sumInput);
|
||||
inputs.put("y", sumInput.dup());
|
||||
|
||||
sameDiff.defineFunction("mmulGradient", new SameDiffFunctionDefinition() {
|
||||
@Override
|
||||
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
|
||||
SDVariable input = sameDiff.var("x", inputs.get("x"));
|
||||
SDVariable input2 = sameDiff.var("y", inputs.get("y"));
|
||||
SDVariable exp = sameDiff.mmul(input, input2);
|
||||
SDVariable sum = sameDiff.sum(exp, Integer.MAX_VALUE);
|
||||
return new SDVariable[]{sum};
|
||||
}
|
||||
sameDiff.defineFunction("mmulGradient", (sameDiff1, inputs1, variableInputs) -> {
|
||||
SDVariable input = sameDiff1.var("x", inputs1.get("x"));
|
||||
SDVariable input2 = sameDiff1.var("y", inputs1.get("y"));
|
||||
SDVariable exp = sameDiff1.mmul(input, input2);
|
||||
SDVariable sum = sameDiff1.sum(exp, Integer.MAX_VALUE);
|
||||
return new SDVariable[]{sum};
|
||||
}, inputs);
|
||||
|
||||
|
||||
|
@ -698,6 +714,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMmulGradients(){
|
||||
int[] aShape = new int[]{2,3};
|
||||
int[] bShape = new int[]{3,4};
|
||||
|
@ -749,7 +767,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBatchMmulBasic() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBatchMmulBasic(Nd4jBackend backend) {
|
||||
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6873
|
||||
int M = 5;
|
||||
int N = 3;
|
||||
|
@ -774,7 +794,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testMmulWithTranspose() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMmulWithTranspose(Nd4jBackend backend) {
|
||||
|
||||
//Here: [x,3]^T * [x,4] = [3,4]
|
||||
|
||||
|
@ -811,6 +833,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMmulOutputSizeCalculation(){
|
||||
//[3,2] x [2,4] with result transpose: output shape [4,3]
|
||||
INDArray a = Nd4j.create(3,2);
|
||||
|
@ -820,7 +844,7 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
.transposeA(false)
|
||||
.transposeB(false)
|
||||
.transposeResult(true)
|
||||
.build());
|
||||
.build());
|
||||
|
||||
val outShapes = Nd4j.getExecutioner().calculateOutputShape(m);
|
||||
assertArrayEquals(new long[]{4,3}, outShapes.get(0).getShape());
|
||||
|
@ -843,6 +867,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testFillOp(){
|
||||
|
||||
INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT);
|
||||
|
@ -857,6 +883,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testClipByNorm(){
|
||||
//Expected: if array.norm2(1) is less than 1.0, not modified
|
||||
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
|
||||
|
@ -889,6 +917,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testClipByNorm2(){
|
||||
//Expected: if array.norm2(1) is less than 1.0, not modified
|
||||
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
|
||||
|
@ -932,6 +962,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testClipByNorm1(){
|
||||
//Expected: if array.norm2(1) is less than 1.0, not modified
|
||||
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
|
||||
|
@ -972,6 +1004,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testClipByNorm0(){
|
||||
//Expected: if array.norm2(0) is less than 1.0, not modified
|
||||
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
|
||||
|
@ -1001,6 +1035,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCumSum(){
|
||||
|
||||
List<String> failing = new ArrayList<>();
|
||||
|
@ -1066,6 +1102,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCumProd(){
|
||||
List<String> failing = new ArrayList<>();
|
||||
|
||||
|
@ -1134,6 +1172,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testOneHot1(){
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
||||
|
@ -1164,6 +1204,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testOneHotOp(){
|
||||
//https://www.tensorflow.org/api_docs/python/tf/one_hot
|
||||
//https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp
|
||||
|
@ -1178,7 +1220,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testOneHot2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testOneHot2(Nd4jBackend backend) {
|
||||
|
||||
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
|
||||
|
||||
|
@ -1198,7 +1242,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testOneHot4() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testOneHot4(Nd4jBackend backend) {
|
||||
|
||||
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
|
||||
|
||||
|
@ -1218,7 +1264,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testOneHot3() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testOneHot3(Nd4jBackend backend) {
|
||||
//https://github.com/deeplearning4j/deeplearning4j/issues/6872
|
||||
|
||||
//https://www.tensorflow.org/api_docs/python/tf/one_hot
|
||||
|
@ -1253,6 +1301,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLinspace(){
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable out = sd.linspace("linspace", DataType.DOUBLE, 1,10,10);
|
||||
|
@ -1266,6 +1316,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLinspace2(){
|
||||
OpValidationSuite.ignoreFailing(); //TODO 2019/01/18
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -1280,7 +1332,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testShapeFn() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testShapeFn(Nd4jBackend backend) {
|
||||
|
||||
INDArray in = Nd4j.create(new long[]{1, 2});
|
||||
|
||||
|
@ -1294,7 +1348,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testShapeFn2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testShapeFn2(Nd4jBackend backend) {
|
||||
|
||||
INDArray i = Nd4j.create(1,3);
|
||||
|
||||
|
@ -1307,6 +1363,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMergeRank1(){
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5));
|
||||
|
@ -1325,7 +1383,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDiagPart() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDiagPart(Nd4jBackend backend) {
|
||||
INDArray i = Nd4j.create(5,5);
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -1337,7 +1397,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDiagShapeFn() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDiagShapeFn(Nd4jBackend backend) {
|
||||
INDArray i = Nd4j.create(5,5);
|
||||
|
||||
CustomOp op = new DiagPart(i, null);
|
||||
|
@ -1350,6 +1412,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testZerosOnesLike(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -1392,6 +1456,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testZerosLikeOp(){
|
||||
|
||||
INDArray arr = Nd4j.scalar(DataType.DOUBLE, 1.0);
|
||||
|
@ -1407,6 +1473,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConfusionMatrix(){
|
||||
DataType dt = DataType.DOUBLE;
|
||||
|
||||
|
@ -1443,6 +1511,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIsNonDecreasingIsStrictlyIncr(){
|
||||
List<long[]> shapes = Arrays.asList(null, new long[]{12}, new long[]{1,12}, new long[]{3,4}, new long[]{2,2,3});
|
||||
|
||||
|
@ -1506,6 +1576,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testExtractImagePatches(){
|
||||
/*
|
||||
tf.reset_default_graph()
|
||||
|
@ -1553,6 +1625,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSegmentProdBpSimple(){
|
||||
|
||||
INDArray segmentIdxs = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT);
|
||||
|
@ -1573,6 +1647,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMmulRank4() throws Exception {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -1608,6 +1684,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMmulRank4_simple(){
|
||||
|
||||
INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
|
||||
|
@ -1634,6 +1712,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNthElementRank1(){
|
||||
INDArray in = Nd4j.createFromArray(new double[]{0,1,2,3,4,5,6,7,8,9});
|
||||
INDArray n = Nd4j.scalar(0);
|
||||
|
@ -1656,6 +1736,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTensorMmulShape(){
|
||||
INDArray a = Nd4j.create(new double[]{2}).reshape(1);
|
||||
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
|
||||
|
@ -1674,6 +1756,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTensorMmulShape2(){
|
||||
INDArray a = Nd4j.create(new double[]{2}).reshape(1);
|
||||
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
|
||||
|
@ -1682,6 +1766,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStopGradient(){
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -1701,6 +1787,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCheckNumerics(){
|
||||
OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/7927
|
||||
|
||||
|
@ -1744,7 +1832,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testCheckNumerics2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCheckNumerics2(Nd4jBackend backend) {
|
||||
INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4);
|
||||
INDArray msg = Nd4j.scalar("My error message!");
|
||||
|
||||
|
@ -1757,6 +1847,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testHistogramFixedWidth(){
|
||||
//Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf]
|
||||
INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9);
|
||||
|
@ -1775,6 +1867,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDynamicPartition(){
|
||||
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
|
||||
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
|
||||
|
@ -1793,6 +1887,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testListDiff(){
|
||||
INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
|
||||
INDArray y = Nd4j.createFromArray(3, 1);
|
||||
|
@ -1812,7 +1908,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDivideNoNan() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDivideNoNan(Nd4jBackend backend) {
|
||||
OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff()
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
@ -1836,7 +1934,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDigamma() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDigamma(Nd4jBackend backend) {
|
||||
|
||||
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
|
||||
|
@ -1851,7 +1951,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testFlatten() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testFlatten(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
|
@ -1873,7 +1975,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testFusedBatchNorm() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testFusedBatchNorm(Nd4jBackend backend) {
|
||||
OpValidationSuite.ignoreFailing();
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
|
@ -1918,7 +2022,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testIgamma() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIgamma(Nd4jBackend backend) {
|
||||
|
||||
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
|
@ -1934,7 +2040,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testIgammaC() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIgammaC(Nd4jBackend backend) {
|
||||
|
||||
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
|
@ -1951,7 +2059,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLgamma() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLgamma(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
|
@ -1976,7 +2086,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLu() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLu(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
|
@ -2007,7 +2119,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMatrixBandPart() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMatrixBandPart(Nd4jBackend backend) {
|
||||
OpValidationSuite.ignoreFailing();
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
|
@ -2037,7 +2151,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testPolygamma() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPolygamma(Nd4jBackend backend) {
|
||||
|
||||
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
|
@ -2053,7 +2169,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testTriangularSolve() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTriangularSolve(Nd4jBackend backend) {
|
||||
|
||||
INDArray a = Nd4j.createFromArray(new float[]{
|
||||
3.f, 0.f, 0.f, 0.f,
|
||||
|
@ -2077,7 +2195,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBiasAdd() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBiasAdd(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
|
@ -2106,7 +2226,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBiasAddGrad() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBiasAddGrad(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
|
@ -2126,7 +2248,9 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRoll() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRoll(Nd4jBackend backend) {
|
||||
|
||||
INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
|
||||
21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}).
|
||||
|
@ -2146,6 +2270,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSeqMask(){
|
||||
INDArray arr = Nd4j.createFromArray(1,2,3);
|
||||
INDArray maxLen = Nd4j.scalar(4);
|
||||
|
|
|
@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.OpValidationSuite;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -51,12 +53,11 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
@Slf4j
|
||||
public class RandomOpValidation extends BaseOpValidation {
|
||||
|
||||
public RandomOpValidation(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRandomOpsSDVarShape() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRandomOpsSDVarShape(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
||||
|
@ -157,7 +158,9 @@ public class RandomOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRandomOpsLongShape() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRandomOpsLongShape(Nd4jBackend backend) {
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
||||
for (long[] shape : Arrays.asList(new long[]{1000}, new long[]{100, 10}, new long[]{40, 5, 5})) {
|
||||
|
@ -283,6 +286,8 @@ public class RandomOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRandomBinomial(){
|
||||
|
||||
INDArray z = Nd4j.create(new long[]{10});
|
||||
|
@ -293,7 +298,9 @@ public class RandomOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testUniformRankSimple() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testUniformRankSimple(Nd4jBackend backend) {
|
||||
|
||||
INDArray arr = Nd4j.createFromArray(new double[]{100.0});
|
||||
// OpTestCase tc = new OpTestCase(DynamicCustomOp.builder("randomuniform")
|
||||
|
@ -325,7 +332,9 @@ public class RandomOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testRandomExponential() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRandomExponential(Nd4jBackend backend) {
|
||||
long length = 1_000_000;
|
||||
INDArray shape = Nd4j.createFromArray(new double[]{length});
|
||||
INDArray out = Nd4j.createUninitialized(new long[]{length});
|
||||
|
@ -347,6 +356,8 @@ public class RandomOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRange(){
|
||||
//Technically deterministic, not random...
|
||||
|
||||
|
@ -380,6 +391,8 @@ public class RandomOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAllEmptyReduce(){
|
||||
INDArray x = Nd4j.createFromArray(true, true, true);
|
||||
All all = new All(x);
|
||||
|
@ -389,6 +402,8 @@ public class RandomOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testUniformDtype(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
|
||||
|
@ -417,6 +432,8 @@ public class RandomOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRandomExponential2(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("random_exponential")
|
||||
|
|
|
@ -24,6 +24,8 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.validation.OpTestCase;
|
||||
import org.nd4j.autodiff.validation.OpValidation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -51,10 +53,6 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
|
||||
private DataType initialType;
|
||||
|
||||
public ReductionBpOpValidation(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
public void before() {
|
||||
Nd4j.create(1);
|
||||
|
@ -71,14 +69,16 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@AfterEach
|
||||
public void tearDown() {
|
||||
public void tearDown(Nd4jBackend backend) {
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testReduceSumBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReduceSumBP(Nd4jBackend backend) {
|
||||
//Full array reduction
|
||||
|
||||
//reduce_sum_bp op: has 2 inputs (original pre-reduce input, and gradient at output (epsilon))
|
||||
|
@ -104,7 +104,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReduceSumAlongDim0BP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReduceSumAlongDim0BP(Nd4jBackend backend) {
|
||||
//Reduction along dimension
|
||||
//Inputs/outputs as before - but note that the output is no longer a scalar
|
||||
|
||||
|
@ -130,7 +132,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReduceSumAlongDim1BP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReduceSumAlongDim1BP(Nd4jBackend backend) {
|
||||
//Reduction along dimension
|
||||
//Inputs/outputs as before - but note that the output is no longer a scalar
|
||||
|
||||
|
@ -158,7 +162,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testMeanBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMeanBP(Nd4jBackend backend) {
|
||||
|
||||
//dL/dIn_i = dL/dOut * dOut/dIn_i = dL/dOut * (1/N * sum_j (in_j))
|
||||
// = 1/N * dL/dOut
|
||||
|
@ -189,7 +195,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMeanBP_Rank1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMeanBP_Rank1(Nd4jBackend backend) {
|
||||
INDArray dLdOut = Nd4j.scalar(0.5);
|
||||
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
|
||||
INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5 / 3);
|
||||
|
@ -202,7 +210,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMeanAlongDim0BP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMeanAlongDim0BP(Nd4jBackend backend) {
|
||||
//Reduction along dimension
|
||||
//Inputs/outputs as before - but note that the output is no longer a scalar
|
||||
|
||||
|
@ -230,7 +240,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMeanAlongDim1BP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMeanAlongDim1BP(Nd4jBackend backend) {
|
||||
//Reduction along dimension
|
||||
//Inputs/outputs as before - but note that the output is no longer a scalar
|
||||
|
||||
|
@ -258,7 +270,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testMinBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMinBP(Nd4jBackend backend) {
|
||||
//Full array min reduction
|
||||
|
||||
//dL/dIn_i = dL/dOut * dOut/dIn_i
|
||||
|
@ -297,7 +311,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMinAlongDimensionBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMinAlongDimensionBP(Nd4jBackend backend) {
|
||||
//Full array min reduction
|
||||
|
||||
//dL/dIn_i = dL/dOut * dOut/dIn_i
|
||||
|
@ -340,7 +356,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMaxBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMaxBP(Nd4jBackend backend) {
|
||||
//Full array max reduction
|
||||
|
||||
//dL/dIn_i = dL/dOut * dOut/dIn_i
|
||||
|
@ -370,7 +388,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMaxAlongDimensionBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMaxAlongDimensionBP(Nd4jBackend backend) {
|
||||
//Full array min reduction
|
||||
|
||||
//dL/dIn_i = dL/dOut * dOut/dIn_i
|
||||
|
@ -413,7 +433,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testProdBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testProdBP(Nd4jBackend backend) {
|
||||
//Full array product reduction
|
||||
|
||||
//dL/dIn_i = dL/dOut * dOut/dIn_i
|
||||
|
@ -442,7 +464,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testProdAlongDimensionBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testProdAlongDimensionBP(Nd4jBackend backend) {
|
||||
//dL/dIn_i = dL/dOut * dOut/dIn_i
|
||||
// = dL/dOut * d(prod(in))/dIn_i
|
||||
// = dL/dOut * (prod(in) / in_i)
|
||||
|
@ -498,7 +522,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStdevBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStdevBP(Nd4jBackend backend) {
|
||||
//If out = stdev(in) then:
|
||||
//dL/dIn = dL/dOut * dOut/dIn
|
||||
//dOut/dIn_i = (in_i-mean)/(stdev * (n-1))
|
||||
|
@ -534,7 +560,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStdevBP_Rank1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStdevBP_Rank1(Nd4jBackend backend) {
|
||||
INDArray dLdOut = Nd4j.scalar(0.5);
|
||||
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
|
||||
double stdev = preReduceInput.stdNumber(true).doubleValue();
|
||||
|
@ -555,7 +583,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStdevAlongDimensionBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStdevAlongDimensionBP(Nd4jBackend backend) {
|
||||
//If out = stdev(in) then:
|
||||
//dL/dIn = dL/dOut * dOut/dIn
|
||||
//dOut/dIn_i = (in_i-mean)/(stdev * (n-1))
|
||||
|
@ -600,7 +630,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testVarianceBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testVarianceBP(Nd4jBackend backend) {
|
||||
//If out = variance(in) then:
|
||||
//dL/dIn = dL/dOut * dOut/dIn
|
||||
//dOut/dIn_i = 2*(in_i-mean)/(n-1)
|
||||
|
@ -636,7 +668,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testVarianceAlongDimensionBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testVarianceAlongDimensionBP(Nd4jBackend backend) {
|
||||
//If out = variance(in) then:
|
||||
//dL/dIn = dL/dOut * dOut/dIn
|
||||
//dOut/dIn_i = 2*(in_i-mean)/(n-1)
|
||||
|
@ -678,7 +712,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testCumSumBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCumSumBP(Nd4jBackend backend) {
|
||||
//Standard case, non-reverse, non-exclusive
|
||||
//dL/dIn_i = sum_j dL/dOut_j * dOut_j/dIn_i
|
||||
// = sum_j dL/dOut_j * d(in_0 + ... + in_j)/dIn_i
|
||||
|
@ -748,7 +784,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testNorm2Bp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNorm2Bp(Nd4jBackend backend) {
|
||||
//dL/dIn = dL/dOut * dOut/dIn
|
||||
// = dL/dOut * x/|x|_2
|
||||
|
||||
|
@ -775,7 +813,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNorm2AlongDimensionBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNorm2AlongDimensionBP(Nd4jBackend backend) {
|
||||
//dL/dIn = dL/dOut * dOut/dIn
|
||||
// = dL/dOut * x/|x|_2
|
||||
|
||||
|
@ -808,7 +848,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNorm1Bp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNorm1Bp(Nd4jBackend backend) {
|
||||
//dL/dIn = dL/dOut * dOut/dIn
|
||||
// = dL/dOut * sgn(in)
|
||||
|
||||
|
@ -835,7 +877,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNorm1AlongDimensionBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNorm1AlongDimensionBP(Nd4jBackend backend) {
|
||||
//dL/dIn = dL/dOut * dOut/dIn
|
||||
// = dL/dOut * sgn(in)
|
||||
|
||||
|
@ -867,7 +911,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNormMaxBp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNormMaxBp(Nd4jBackend backend) {
|
||||
//out = max_i (|in_i|)
|
||||
//dL/dIn = dL/dOut * dOut/dIn
|
||||
// = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise)
|
||||
|
@ -897,7 +943,9 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNormMaxAlongDimensionBP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNormMaxAlongDimensionBP(Nd4jBackend backend) {
|
||||
//out = max_i (|in_i|)
|
||||
//dL/dIn = dL/dOut * dOut/dIn
|
||||
// = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise)
|
||||
|
|
|
@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
|
|||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.OpValidationSuite;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -76,16 +77,13 @@ import java.util.List;
|
|||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
|
||||
public class ReductionOpValidation extends BaseOpValidation {
|
||||
|
||||
|
||||
public ReductionOpValidation(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStdev() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStdev(Nd4jBackend backend) {
|
||||
List<String> errors = new ArrayList<>();
|
||||
|
||||
for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345, DataType.DOUBLE)) {
|
||||
|
@ -111,7 +109,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testZeroCount() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testZeroCount(Nd4jBackend backend) {
|
||||
List<String> allFailed = new ArrayList<>();
|
||||
for (int i = 0; i < 21; i++) {
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -145,7 +145,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testZeroFraction() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testZeroFraction(Nd4jBackend backend) {
|
||||
List<String> allFailed = new ArrayList<>();
|
||||
for (int i = 0; i < 2; i++) {
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -175,7 +177,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReductionGradientsSimple() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReductionGradientsSimple(Nd4jBackend backend) {
|
||||
//OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES
|
||||
//Test reductions: final and only function
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -344,7 +348,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReductionGradients1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReductionGradients1(Nd4jBackend backend) {
|
||||
//Test reductions: final, but *not* the only function
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -472,7 +478,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReductionGradients2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReductionGradients2(Nd4jBackend backend) {
|
||||
//Test reductions: NON-final function
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -650,7 +658,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testReduce3() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReduce3(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
int d0 = 3;
|
||||
|
@ -755,7 +765,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMoments() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMoments(Nd4jBackend backend) {
|
||||
for (int[] axes : new int[][]{{0}, {1}, {0, 1}}) {
|
||||
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
|
||||
|
@ -787,9 +799,11 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMomentsOp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMomentsOp(Nd4jBackend backend) {
|
||||
int[] axes = new int[]{0};
|
||||
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
|
||||
INDArray outMean = Nd4j.createUninitialized(new long[]{4});
|
||||
INDArray outVar = Nd4j.createUninitialized(new long[]{4});
|
||||
|
@ -804,7 +818,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNormalizeMomentsOp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNormalizeMomentsOp(Nd4jBackend backend) {
|
||||
INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10);
|
||||
INDArray ssSum = data.sum(0);
|
||||
INDArray ssSqSum = data.mul(data).sum(0);
|
||||
|
@ -824,7 +840,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAllAny() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAllAny(Nd4jBackend backend) {
|
||||
|
||||
INDArray allZeros = Nd4j.zeros(DataType.FLOAT, 3, 4);
|
||||
INDArray allOnes = Nd4j.ones(DataType.FLOAT, 3, 4);
|
||||
|
@ -852,7 +870,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testIndexAccum() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIndexAccum(Nd4jBackend backend) {
|
||||
List<String> failed = new ArrayList<>();
|
||||
List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/);
|
||||
|
||||
|
@ -941,7 +961,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testReduce3_2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReduce3_2(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
int d0 = 3;
|
||||
|
@ -1039,7 +1061,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReductionsBackwards() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReductionsBackwards(Nd4jBackend backend) {
|
||||
// for (int i = 0; i < 7; i++) {
|
||||
int i=5;
|
||||
{
|
||||
|
@ -1108,6 +1132,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDotProductAttention(){
|
||||
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
|
||||
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
|
||||
|
@ -1127,12 +1153,14 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
t.norm1("out");
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.expectedOutput("out", finalOut)
|
||||
.gradientCheck(true));
|
||||
.expectedOutput("out", finalOut)
|
||||
.gradientCheck(true));
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDotProductAttentionWithMask(){
|
||||
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
|
||||
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
|
||||
|
@ -1163,6 +1191,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDotProductAttentionMultiHeadInputWithMask(){
|
||||
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
|
||||
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
|
||||
|
@ -1194,6 +1224,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDotProductAttentionMultiHeadInput(){
|
||||
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
|
||||
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
|
||||
|
@ -1221,6 +1253,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMultiHeadedDotProductAttention(){
|
||||
final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
|
||||
final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
|
||||
|
@ -1272,6 +1306,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDotProductAttentionWeirdInputs(){
|
||||
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
|
||||
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
|
||||
|
@ -1309,6 +1345,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMultiHeadedDotProductAttentionWeirdInputs(){
|
||||
final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
|
||||
final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
|
||||
|
@ -1366,7 +1404,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
}
|
||||
@Test
|
||||
public void testSufficientStatisticsOp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSufficientStatisticsOp(Nd4jBackend backend) {
|
||||
INDArray data = Nd4j.createFromArray(new double[]{
|
||||
5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1.,
|
||||
1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5
|
||||
|
@ -1392,7 +1432,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStandardDeviation() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStandardDeviation(Nd4jBackend backend) {
|
||||
|
||||
for (boolean keepDims : new boolean[]{false, true}) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
@ -1419,7 +1461,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSquaredNorm() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSquaredNorm(Nd4jBackend backend) {
|
||||
|
||||
for (boolean keepDims : new boolean[]{false, true}) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
@ -1442,7 +1486,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testShannonEntropy() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testShannonEntropy(Nd4jBackend backend) {
|
||||
OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
@ -1462,7 +1508,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEntropy() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEntropy(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
|
@ -1481,7 +1529,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAMean() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAMean(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
|
@ -1502,7 +1552,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMean() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMean(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
|
@ -1523,7 +1575,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNorm1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNorm1(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
|
@ -1544,7 +1598,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNorm2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNorm2(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
|
@ -1565,7 +1621,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNormMax() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNormMax(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
|
@ -1586,7 +1644,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSoftmaxCrossEntropyWithLogitsLoss() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSoftmaxCrossEntropyWithLogitsLoss(Nd4jBackend backend) {
|
||||
OpValidationSuite.ignoreFailing();
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
|
|
@ -22,6 +22,8 @@ package org.nd4j.autodiff.opvalidation;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -43,12 +45,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
|
||||
@Slf4j
|
||||
public class RnnOpValidation extends BaseOpValidation {
|
||||
public RnnOpValidation(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRnnBlockCell(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRnnBlockCell(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int mb = 2;
|
||||
int nIn = 3;
|
||||
|
@ -147,7 +148,9 @@ public class RnnOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testRnnBlockCellManualTFCompare() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRnnBlockCellManualTFCompare(Nd4jBackend backend) {
|
||||
//Test case: "rnn/lstmblockcell/static_batch1_n3-2_tsLength1_noPH_noClip_fBias1_noIS"
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -209,6 +212,8 @@ public class RnnOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGRUCell(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int mb = 2;
|
||||
|
|
|
@ -28,6 +28,8 @@ import lombok.val;
|
|||
import org.apache.commons.math3.linear.LUDecomposition;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.OpValidationSuite;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -67,9 +69,6 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
|
|||
|
||||
@Slf4j
|
||||
public class ShapeOpValidation extends BaseOpValidation {
|
||||
public ShapeOpValidation(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
/*
|
||||
To test:
|
||||
|
@ -83,7 +82,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
*/
|
||||
|
||||
@Test
|
||||
public void testConcat() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConcat(Nd4jBackend backend) {
|
||||
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
|
||||
int[] concatDim = new int[]{0, 0, 0};
|
||||
List<List<int[]>> origShapes = new ArrayList<>();
|
||||
|
@ -123,7 +124,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReshapeGradient() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReshapeGradient(Nd4jBackend backend) {
|
||||
//https://github.com/deeplearning4j/deeplearning4j/issues/6873
|
||||
|
||||
int[] origShape = new int[]{3, 4, 5};
|
||||
|
@ -159,7 +162,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testPermuteGradient() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPermuteGradient(Nd4jBackend backend) {
|
||||
int[] origShape = new int[]{3, 4, 5};
|
||||
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
@ -197,6 +202,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRank(){
|
||||
|
||||
List<long[]> inShape = Arrays.asList(null, new long[]{1}, new long[]{6}, new long[]{3,4}, new long[]{3,4,5});
|
||||
|
@ -224,7 +231,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testExpandDimsGradient() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testExpandDimsGradient(Nd4jBackend backend) {
|
||||
val origShape = new long[]{3, 4};
|
||||
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
@ -280,7 +289,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSqueezeGradient() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSqueezeGradient(Nd4jBackend backend) {
|
||||
val origShape = new long[]{3, 4, 5};
|
||||
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
@ -344,7 +355,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testSliceGradient() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSliceGradient(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
//Order here: original shape, begin, size
|
||||
|
@ -434,7 +447,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStridedSliceGradient() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStridedSliceGradient(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
//Order here: original shape, begin, size
|
||||
|
@ -497,7 +512,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testMerge() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMerge(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
@ -573,7 +590,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test()
|
||||
public void testStack() {
|
||||
public void testStack(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
@ -664,7 +681,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testUnStack() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testUnStack(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
@ -752,7 +771,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testTile() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTile(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
List<int[]> tileArg = Arrays.asList(
|
||||
|
@ -824,6 +845,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTileBp(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -857,6 +880,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTileBp2(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -891,7 +916,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testReshape() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReshape(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray arr = Transforms.sigmoid(Nd4j.linspace(-5, 6, 12)).reshape(3, 4);
|
||||
SDVariable x = sameDiff.var("x", arr);
|
||||
|
@ -907,7 +934,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReshape2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReshape2(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
int[] origShape = new int[]{3, 4, 5};
|
||||
|
||||
|
@ -930,7 +959,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testTranspose() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTranspose(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
|
||||
SDVariable x = sameDiff.var("x", arr);
|
||||
|
@ -942,6 +973,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTransposeOp(){
|
||||
|
||||
INDArray arr = Nd4j.linspace(1,15, 15).reshape(5,3);
|
||||
|
@ -955,7 +988,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testShape() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testShape(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
val shape = new long[]{2, 3};
|
||||
SDVariable x = sameDiff.var("x", shape);
|
||||
|
@ -970,7 +1005,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSize() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSize(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
val shape = new long[]{2, 3};
|
||||
SDVariable x = sameDiff.var("x", DataType.FLOAT, shape);
|
||||
|
@ -984,7 +1021,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDiagShapeFn() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDiagShapeFn(Nd4jBackend backend) {
|
||||
INDArray i = Nd4j.linspace(1, 16, 16).reshape(4,4);
|
||||
|
||||
OpTestCase op = new OpTestCase(new DiagPart(i, null));
|
||||
|
@ -998,6 +1037,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPermute(){
|
||||
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
|
||||
INDArray exp = in.permute(0,1,2); //No op
|
||||
|
@ -1012,6 +1053,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPermute2(){
|
||||
for (int[] perm : new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}) {
|
||||
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
|
||||
|
@ -1032,6 +1075,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConstant(){
|
||||
//OpValidationSuite.ignoreFailing();
|
||||
|
||||
|
@ -1059,6 +1104,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testUnstackEdgeCase2(){
|
||||
for( int i=0; i<3; i++ ) {
|
||||
|
||||
|
@ -1073,7 +1120,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void invertPermutation() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void invertPermutation(Nd4jBackend backend) {
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
INDArray ia = Nd4j.create(new float[] {3, 4, 0, 2, 1}).castTo(DataType.INT);
|
||||
|
@ -1090,6 +1139,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGatherNd(){
|
||||
|
||||
List<INDArray> indices = new ArrayList<>();
|
||||
|
@ -1128,7 +1179,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testReverseSequence() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReverseSequence(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
float[] input_data = new float[]{
|
||||
1, 2, 3,
|
||||
|
@ -1174,6 +1227,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMatrixDeterminant(){
|
||||
OpValidationSuite.ignoreFailing(); //Gradient check failing
|
||||
|
||||
|
@ -1195,6 +1250,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDeterminant22(){
|
||||
OpValidationSuite.ignoreFailing(); //Gradient check failing
|
||||
|
||||
|
@ -1219,6 +1276,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMatrixDeterminant3(){
|
||||
OpValidationSuite.ignoreFailing(); //Gradient checks failing
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -1250,6 +1309,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMatrixDeterminant4(){
|
||||
OpValidationSuite.ignoreFailing(); //Gradient checks failing
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -1270,6 +1331,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSegmentOps(){
|
||||
OpValidationSuite.ignoreFailing();
|
||||
//https://github.com/deeplearning4j/deeplearning4j/issues/6952
|
||||
|
@ -1362,6 +1425,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSegmentMean(){
|
||||
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3);
|
||||
INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2);
|
||||
|
@ -1382,7 +1447,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSequenceMask() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSequenceMask(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2});
|
||||
// arr is not trainable, so it's constant in model
|
||||
|
@ -1391,10 +1458,10 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
// Test with static max len
|
||||
int maxlen = 5;
|
||||
INDArray expected = Nd4j.create(new float[] {
|
||||
1.f, 0.f, 0.f, 0.f, 0.f,
|
||||
1.f, 1.f, 1.f, 0.f, 0.f,
|
||||
1.f, 1.f, 0.f, 0.f, 0.f
|
||||
}).reshape(3,5);
|
||||
1.f, 0.f, 0.f, 0.f, 0.f,
|
||||
1.f, 1.f, 1.f, 0.f, 0.f,
|
||||
1.f, 1.f, 0.f, 0.f, 0.f
|
||||
}).reshape(3,5);
|
||||
INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.FLOAT));
|
||||
SDVariable result1 = sameDiff.sequenceMask(lengths, maxlen, DataType.FLOAT);
|
||||
assertArrayEquals(expected.shape(), result1.eval().shape());
|
||||
|
@ -1416,6 +1483,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMeshGrid(){
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
||||
|
@ -1472,6 +1541,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGather(){
|
||||
List<INDArray> inArrs = new ArrayList<>();
|
||||
List<Integer> axis = new ArrayList<>();
|
||||
|
@ -1541,7 +1612,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGatherSimple() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGatherSimple(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray arr = Nd4j.create(new float[]{1, 2, 3, 4}, new long[]{2, 2});
|
||||
SDVariable x = sameDiff.var("x", arr);
|
||||
|
@ -1551,7 +1624,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGatherNdSingle() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGatherNdSingle(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(DataType.DOUBLE, 1, 24, 24)).reshape(2, 3, 4);
|
||||
INDArray arr2 = Nd4j.create(new float[]{1, 2, 3, 0, 1, 3, 1, 0, 2}, new long[]{3, 3}).castTo(DataType.INT);
|
||||
|
@ -1563,14 +1638,16 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
for (int i=0; i<3; i++){
|
||||
INDArray idx = arr2.get(point(i), NDArrayIndex.all());
|
||||
expected.putScalar(i, arr1.get(point(idx.getInt(0)),
|
||||
point(idx.getInt(1)),
|
||||
point(idx.getInt(2))).getDouble(0));
|
||||
point(idx.getInt(1)),
|
||||
point(idx.getInt(2))).getDouble(0));
|
||||
}
|
||||
assertEquals(expected, result.eval());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStack2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStack2(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
|
||||
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2);
|
||||
|
@ -1581,7 +1658,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testParallelStack() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testParallelStack(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
|
||||
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(7, 12, 6)).reshape(3, 2);
|
||||
|
@ -1593,7 +1672,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testUnStack2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testUnStack2(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray arr1 = Nd4j.zeros(3, 2);
|
||||
INDArray arr2 = Nd4j.ones(3, 2);
|
||||
|
@ -1606,7 +1687,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testPermuteSimple() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPermuteSimple(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3));
|
||||
SDVariable x = sameDiff.var("x", arr);
|
||||
|
@ -1617,7 +1700,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testConcat2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConcat2(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
|
||||
INDArray arr2 = Transforms.sigmoid(Nd4j.linspace(4, 8, 4)).reshape(1,4);
|
||||
|
@ -1628,7 +1713,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testTile2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTile2(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1,4));
|
||||
SDVariable x = sameDiff.var("x", arr);
|
||||
|
@ -1641,7 +1728,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testSlice2d() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSlice2d(Nd4jBackend backend) {
|
||||
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -1657,7 +1746,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testSlice3d() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSlice3d(Nd4jBackend backend) {
|
||||
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -1672,7 +1763,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStridedSlice2dBasic() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStridedSlice2dBasic(Nd4jBackend backend) {
|
||||
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -1690,7 +1783,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testStridedSliceBeginEndMask() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStridedSliceBeginEndMask(Nd4jBackend backend) {
|
||||
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -1705,7 +1800,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStridedSliceEllipsisMask() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStridedSliceEllipsisMask(Nd4jBackend backend) {
|
||||
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.var("in", inArr);
|
||||
|
@ -1722,7 +1819,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStridedSliceNewAxisMask() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStridedSliceNewAxisMask(Nd4jBackend backend) {
|
||||
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.var("in", inArr);
|
||||
|
@ -1735,7 +1834,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStridedSliceNewAxisMask2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStridedSliceNewAxisMask2(Nd4jBackend backend) {
|
||||
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.var("in", inArr);
|
||||
|
@ -1746,7 +1847,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStridedSliceShrinkAxisMask() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStridedSliceShrinkAxisMask(Nd4jBackend backend) {
|
||||
|
||||
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -1763,7 +1866,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSizeAt_1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSizeAt_1(Nd4jBackend backend) {
|
||||
val array = Nd4j.create(10, 20, 30);
|
||||
val exp = Nd4j.scalar(DataType.LONG, 20);
|
||||
|
||||
|
@ -1777,6 +1882,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEye(){
|
||||
int[] rows = new int[]{3,3,3,3};
|
||||
int[] cols = new int[]{3,2,2,2};
|
||||
|
@ -1815,6 +1922,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSplit1(){
|
||||
INDArray in = Nd4j.linspace(1,10,10).reshape(10);
|
||||
INDArray axis = Nd4j.scalar(-1);
|
||||
|
@ -1833,6 +1942,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSplit2(){
|
||||
INDArray in = Nd4j.linspace(1,24,24).reshape(3,8);
|
||||
INDArray axis = Nd4j.scalar(-1);
|
||||
|
@ -1851,6 +1962,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDistancesExec(){
|
||||
//https://github.com/deeplearning4j/deeplearning4j/issues/7001
|
||||
for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) {
|
||||
|
@ -1906,6 +2019,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReductionShape(){
|
||||
|
||||
INDArray shape = Nd4j.createFromArray(4,2);
|
||||
|
@ -1924,6 +2039,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void gatherTest(){
|
||||
INDArray in = Nd4j.createFromArray(new double[][]{
|
||||
{1,2,3,4,5},
|
||||
|
@ -1943,6 +2060,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSliceShape(){
|
||||
|
||||
INDArray arr = Nd4j.arange(0, 25).reshape(1,5,5).castTo(DataType.INT);
|
||||
|
@ -1964,6 +2083,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testWhereAllFalse(){
|
||||
INDArray in = Nd4j.create(DataType.BOOL, 1917);
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("Where")
|
||||
|
@ -1978,6 +2099,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGatherScalar(){
|
||||
INDArray in = Nd4j.linspace(100, 200, 100, DataType.FLOAT).reshape(100);
|
||||
INDArray indices = Nd4j.scalar(0);
|
||||
|
@ -2002,6 +2125,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCastEmpty(){
|
||||
INDArray emptyLong = Nd4j.empty(DataType.LONG);
|
||||
int dtype = 9; //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h
|
||||
|
@ -2018,6 +2143,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGatherEmpty(){
|
||||
/*
|
||||
tf.reset_default_graph()
|
||||
|
@ -2050,6 +2177,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSplitEmpty(){
|
||||
/*
|
||||
tf.reset_default_graph()
|
||||
|
@ -2087,6 +2216,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConcatEmpty(){
|
||||
/*
|
||||
TF behaviour with concatenatioun of empty arrays:
|
||||
|
@ -2136,6 +2267,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConcatEmpty2(){
|
||||
INDArray empty10a = Nd4j.create(DataType.INT, 1, 0);
|
||||
INDArray empty10b = Nd4j.create(DataType.INT, 1, 0);
|
||||
|
@ -2168,6 +2301,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEmptyGather(){
|
||||
/*
|
||||
tf.reset_default_graph()
|
||||
|
@ -2200,6 +2335,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBroadcastDynamicShape1(){
|
||||
|
||||
//Test case: [2,1] and [4]: expect [2,4]
|
||||
|
@ -2221,6 +2358,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBroadcastDynamicShape2(){
|
||||
|
||||
//Test case: [2,1,4] and [2,2,4]: expect [2,2,4]
|
||||
|
@ -2243,6 +2382,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStridedSliceShrinkAxis(){
|
||||
INDArray in = Nd4j.create(DataType.DOUBLE, 3,2,2);
|
||||
INDArray begin = Nd4j.createFromArray(2);
|
||||
|
@ -2268,6 +2409,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStridedSliceEmpty(){
|
||||
|
||||
INDArray in = Nd4j.createFromArray(10); //Integer, Length 1, rank 1, value 10 - Not used due to begin mask!
|
||||
|
@ -2290,6 +2433,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStridedSliceEdgeCase(){
|
||||
INDArray in = Nd4j.scalar(10).reshape(1); //Int [1]
|
||||
INDArray begin = Nd4j.ones(DataType.INT, 1);
|
||||
|
@ -2315,6 +2460,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEmptySlice1(){
|
||||
INDArray in = Nd4j.createFromArray(38);
|
||||
INDArray begin = Nd4j.createFromArray(1);
|
||||
|
@ -2334,6 +2481,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEmptySlice2(){
|
||||
INDArray in = Nd4j.createFromArray(38);
|
||||
INDArray begin = Nd4j.createFromArray(0);
|
||||
|
@ -2353,6 +2502,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testFill(){
|
||||
|
||||
INDArray shape = Nd4j.createFromArray(0,4);
|
||||
|
@ -2372,6 +2523,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testFill2(){
|
||||
|
||||
INDArray shape = Nd4j.createFromArray(0,4);
|
||||
|
@ -2389,6 +2542,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPermuteShapeDynamicAxis(){
|
||||
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("permute")
|
||||
|
@ -2418,6 +2573,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGather2(){
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable input = sd.var("in", Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3));
|
||||
|
@ -2437,6 +2594,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPermute3(){
|
||||
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
|
||||
INDArray permute = Nd4j.createFromArray(1,0);
|
||||
|
@ -2455,6 +2614,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPermute4(){
|
||||
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
|
||||
INDArray permute = Nd4j.createFromArray(1,0);
|
||||
|
@ -2485,6 +2646,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testInvertPermutation(){
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("invert_permutation")
|
||||
.addInputs(Nd4j.createFromArray(1, 0))
|
||||
|
@ -2492,7 +2655,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBroadcastInt1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBroadcastInt1(Nd4jBackend backend) {
|
||||
|
||||
INDArray out = Nd4j.create(DataType.INT, 1);
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
|
||||
|
@ -2505,6 +2670,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBroadcastInt2(){
|
||||
INDArray out = Nd4j.create(DataType.INT, 2);
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
|
||||
|
@ -2544,7 +2711,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testMergeMaxIndex() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMergeMaxIndex(Nd4jBackend backend) {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -2561,7 +2730,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testTriOp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTriOp(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable out = new Tri(sd, DataType.INT32, 3, 5, 2).outputVariable();
|
||||
|
@ -2573,7 +2744,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testTriuOp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTriuOp(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable input = sd.var(Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {7,8,9},{10,11,12}}));
|
||||
|
@ -2581,8 +2754,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
out.markAsLoss();
|
||||
INDArray expected = Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {0,8,9},{0,0,12}});
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.expectedOutput("triu", expected)
|
||||
.gradientCheck(true));
|
||||
.expectedOutput("triu", expected)
|
||||
.gradientCheck(true));
|
||||
assertNull(err);
|
||||
|
||||
}
|
||||
|
|
|
@ -26,6 +26,8 @@ import org.junit.jupiter.api.AfterEach;
|
|||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.OpValidationSuite;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -94,9 +96,6 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
|
||||
private DataType initialType;
|
||||
|
||||
public TransformOpValidation(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
public void before() {
|
||||
|
@ -120,7 +119,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testScalarOps() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testScalarOps(Nd4jBackend backend) {
|
||||
int d0 = 2;
|
||||
int d1 = 3;
|
||||
int d2 = 4;
|
||||
|
@ -217,7 +218,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testScalarMulCF() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testScalarMulCF(Nd4jBackend backend) {
|
||||
|
||||
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
|
||||
INDArray outC = Nd4j.createUninitialized(3, 4);
|
||||
|
@ -231,7 +234,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testScalarMulCF2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testScalarMulCF2(Nd4jBackend backend) {
|
||||
|
||||
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
|
||||
|
||||
|
@ -242,7 +247,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testCross() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCross(Nd4jBackend backend) {
|
||||
INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3});
|
||||
INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3});
|
||||
|
||||
|
@ -270,7 +277,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSpaceToDepth() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSpaceToDepth(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(1337);
|
||||
|
||||
int miniBatch = 128;
|
||||
|
@ -298,7 +307,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDepthToSpace() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDepthToSpace(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(1337);
|
||||
|
||||
int miniBatch = 128;
|
||||
|
@ -325,7 +336,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBatchToSpace() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBatchToSpace(Nd4jBackend backend) {
|
||||
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
|
||||
Nd4j.getRandom().setSeed(1337);
|
||||
|
||||
|
@ -362,7 +375,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSpaceToBatch() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSpaceToBatch(Nd4jBackend backend) {
|
||||
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
|
||||
|
||||
Nd4j.getRandom().setSeed(7331);
|
||||
|
@ -400,7 +415,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDynamicPartition() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDynamicPartition(Nd4jBackend backend) {
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
INDArray ia = Nd4j.create(new double[]{4, 3, 5, 7, 8, 0});
|
||||
|
@ -440,7 +457,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDynamicPartition2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDynamicPartition2(Nd4jBackend backend) {
|
||||
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
|
||||
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
|
||||
INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
|
||||
|
@ -458,7 +477,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDynamicStitch() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDynamicStitch(Nd4jBackend backend) {
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
INDArray ia = Nd4j.create(new double[]{5, 1, 3}, new long[]{3});
|
||||
|
@ -495,7 +516,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDiag() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDiag(Nd4jBackend backend) {
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
INDArray ia = Nd4j.create(new double[]{1, 2}, new int[]{2});
|
||||
|
@ -521,7 +544,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDiagPart() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDiagPart(Nd4jBackend backend) {
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
INDArray input = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
|
||||
|
@ -540,7 +565,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEye() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEye(Nd4jBackend backend) {
|
||||
int[] rows = new int[]{3, 3, 3, 3};
|
||||
int[] cols = new int[]{3, 2, 2, 2};
|
||||
int[][] batch = new int[][]{{}, {}, {4}, {3, 3}};
|
||||
|
@ -574,7 +601,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEyeShape() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEyeShape(Nd4jBackend backend) {
|
||||
DynamicCustomOp dco = DynamicCustomOp.builder("eye")
|
||||
.addIntegerArguments(3, 3)
|
||||
//.addIntegerArguments(-99,3,3) //Also fails
|
||||
|
@ -586,7 +615,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testTransforms() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTransforms(Nd4jBackend backend) {
|
||||
//Test transforms (non-pairwise)
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -1074,7 +1105,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testPairwiseTransforms() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPairwiseTransforms(Nd4jBackend backend) {
|
||||
/*
|
||||
add, sub, mul, div, rsub, rdiv
|
||||
eq, neq, gt, lt, gte, lte, or, and, xor
|
||||
|
@ -1258,7 +1291,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testIsX() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIsX(Nd4jBackend backend) {
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
|
@ -1313,7 +1348,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReplaceWhereScalar() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReplaceWhereScalar(Nd4jBackend backend) {
|
||||
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
|
||||
|
||||
log.info("Testing condition: " + c.getClass().getSimpleName());
|
||||
|
@ -1335,7 +1372,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReplaceWhereArray() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReplaceWhereArray(Nd4jBackend backend) {
|
||||
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
|
||||
|
||||
INDArray inArr = Nd4j.rand(3, 4);
|
||||
|
@ -1358,7 +1397,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLogGrad() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLogGrad(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
SDVariable input = sameDiff.var("x", Nd4j.linspace(1, 4, 4, DataType.DOUBLE));
|
||||
SDVariable log = sameDiff.math().log(input);
|
||||
|
@ -1369,7 +1410,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testSigmoidBackwards() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSigmoidBackwards(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
Map<String, INDArray> inputs = new HashMap<>();
|
||||
|
@ -1386,8 +1429,10 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
|
||||
/* @Test
|
||||
public void testDepth() {
|
||||
/* @Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDepth(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
SDVariable x = sameDiff.one("one",new long[]{2,2});
|
||||
assertEquals(0,x.depth());
|
||||
|
@ -1396,7 +1441,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}*/
|
||||
|
||||
@Test
|
||||
public void testRank0EdgeCase() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRank0EdgeCase(Nd4jBackend backend) {
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4})));
|
||||
double d0 = v1.eval().getDouble(0);
|
||||
|
@ -1409,7 +1456,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAtan2BroadcastShape() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAtan2BroadcastShape(Nd4jBackend backend) {
|
||||
INDArray arr1 = Nd4j.create(new long[]{3, 1, 4});
|
||||
INDArray arr2 = Nd4j.create(new long[]{1, 2, 4});
|
||||
|
||||
|
@ -1424,7 +1473,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBooleanAnd() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBooleanAnd(Nd4jBackend backend) {
|
||||
Nd4j.setDataType(DataType.FLOAT);
|
||||
INDArray arr1 = Nd4j.create(new long[]{3, 4});
|
||||
INDArray arr2 = Nd4j.create(new long[]{3, 4});
|
||||
|
@ -1438,7 +1489,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testScatterOpsScalar() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testScatterOpsScalar(Nd4jBackend backend) {
|
||||
for (String s : new String[]{"add", "sub", "mul", "div"}) {
|
||||
INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
|
||||
INDArray indices = Nd4j.scalar(5);
|
||||
|
@ -1483,7 +1536,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
|
||||
@Disabled("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540")
|
||||
@Test
|
||||
public void testPad() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPad(Nd4jBackend backend) {
|
||||
INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
|
||||
INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG);
|
||||
INDArray value = Nd4j.scalar(10.0);
|
||||
|
@ -1510,7 +1565,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testMirrorPad() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMirrorPad(Nd4jBackend backend) {
|
||||
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
|
||||
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
|
||||
|
||||
|
@ -1543,7 +1600,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMirrorPad2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMirrorPad2(Nd4jBackend backend) {
|
||||
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
|
||||
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
|
||||
|
||||
|
@ -1569,7 +1628,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMirrorPadSymmetric() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMirrorPadSymmetric(Nd4jBackend backend) {
|
||||
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4);
|
||||
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT);
|
||||
|
||||
|
@ -1596,7 +1657,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testUnique() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testUnique(Nd4jBackend backend) {
|
||||
INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4});
|
||||
|
||||
INDArray expUnique = Nd4j.create(new double[]{3, 4, 1, 0, 2});
|
||||
|
@ -1618,7 +1681,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testTopK() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTopK(Nd4jBackend backend) {
|
||||
OpValidationSuite.ignoreFailing(); //Can't assume sorted here
|
||||
INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8});
|
||||
|
||||
|
@ -1647,7 +1712,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testTopK1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTopK1(Nd4jBackend backend) {
|
||||
INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
|
||||
INDArray k = Nd4j.scalar(1);
|
||||
INDArray outValue = Nd4j.create(DataType.DOUBLE, 1);
|
||||
|
@ -1668,7 +1735,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testInTopK() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testInTopK(Nd4jBackend backend) {
|
||||
for (int k = 4; k >= 1; k--) {
|
||||
log.info("Testing: k=" + k);
|
||||
INDArray in = Nd4j.linspace(1, 20, 20, DataType.DOUBLE).reshape(4, 5);
|
||||
|
@ -1709,7 +1778,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testZeta() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testZeta(Nd4jBackend backend) {
|
||||
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182
|
||||
INDArray x = Nd4j.rand(3, 4).addi(1.0);
|
||||
INDArray q = Nd4j.rand(3, 4);
|
||||
|
@ -1726,7 +1797,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMaxEmptyScalar() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMaxEmptyScalar(Nd4jBackend backend) {
|
||||
INDArray empty = Nd4j.empty(DataType.FLOAT);
|
||||
INDArray scalar = Nd4j.scalar(1.0f);
|
||||
|
||||
|
@ -1743,7 +1816,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBroadcastEmpty() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBroadcastEmpty(Nd4jBackend backend) {
|
||||
// Nd4j.getExecutioner().enableVerboseMode(true);
|
||||
// Nd4j.getExecutioner().enableDebugMode(true);
|
||||
//Check broadcast behaviour with empty arrays. The idea is to match TF import behaviour, for import
|
||||
|
@ -1833,7 +1908,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStandardize() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStandardize(Nd4jBackend backend) {
|
||||
final INDArray random = Nd4j.rand(new int[]{10, 4});
|
||||
|
||||
final int[] axis = new int[]{1};
|
||||
|
@ -1854,7 +1931,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStandardizeOP() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStandardizeOP(Nd4jBackend backend) {
|
||||
final INDArray random = Nd4j.rand(new int[]{10, 4});
|
||||
|
||||
final int[] axis = new int[]{1};
|
||||
|
@ -1869,7 +1948,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStandardizeNoDeviation() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStandardizeNoDeviation(Nd4jBackend backend) {
|
||||
final INDArray random = Nd4j.rand(new int[]{10, 4});
|
||||
for (int i = 0; i < 4; i++) {
|
||||
random.putScalar(1, i, 7);
|
||||
|
@ -1895,7 +1976,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMatMulTensor() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMatMulTensor(Nd4jBackend backend) {
|
||||
final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5});
|
||||
final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6});
|
||||
|
||||
|
@ -1915,7 +1998,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMatMulTensorTranspose() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMatMulTensorTranspose(Nd4jBackend backend) {
|
||||
for (boolean transposeA : new boolean[]{false, true}) {
|
||||
for (boolean transposeB : new boolean[]{false, true}) {
|
||||
for (boolean transposeResult : new boolean[]{false, true}) {
|
||||
|
@ -2008,7 +2093,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSoftmaxCF() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSoftmaxCF(Nd4jBackend backend) {
|
||||
|
||||
INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5);
|
||||
INDArray arrF = arrC.dup('f');
|
||||
|
@ -2029,7 +2116,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLogSumExp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLogSumExp(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4);
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -2044,7 +2133,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testLogSumExp2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLogSumExp2(Nd4jBackend backend) {
|
||||
|
||||
for (int dim = 0; dim <= 2; dim++) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -2065,7 +2156,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testCRELU() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCRELU(Nd4jBackend backend) {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2);
|
||||
|
@ -2084,7 +2177,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testClipByAvgNorm() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testClipByAvgNorm(Nd4jBackend backend) {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2, 2);
|
||||
|
@ -2105,7 +2200,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEmbeddingLookup() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEmbeddingLookup(Nd4jBackend backend) {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -2118,49 +2215,53 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testImageResize() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testImageResize(Nd4jBackend backend) {
|
||||
|
||||
//TODO: Methods failed ResizeLanczos5, ResizeMitchelcubic, ResizeArea
|
||||
|
||||
for (ImageResizeMethod method : ImageResizeMethod.values()) {
|
||||
if (method==ImageResizeMethod.ResizeLanczos5 || method==ImageResizeMethod.ResizeArea || method == ImageResizeMethod.ResizeMitchelcubic)
|
||||
{continue;}
|
||||
if (method==ImageResizeMethod.ResizeLanczos5 || method==ImageResizeMethod.ResizeArea || method == ImageResizeMethod.ResizeMitchelcubic)
|
||||
{continue;}
|
||||
|
||||
log.info("Trying {}", method);
|
||||
log.info("Trying {}", method);
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
boolean preserveAspectRatio = true;
|
||||
boolean antialias = true;
|
||||
SDVariable inputImage = sd.var(Nd4j.rand(DataType.FLOAT, 1, 5, 5, 3));
|
||||
// NHWC format
|
||||
long[] expectedShape = new long[]{1, 3, 3, 3};
|
||||
SDVariable requestedSize = sd.constant(Nd4j.createFromArray( new long[]{3, 3}));
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
boolean preserveAspectRatio = true;
|
||||
boolean antialias = true;
|
||||
SDVariable inputImage = sd.var(Nd4j.rand(DataType.FLOAT, 1, 5, 5, 3));
|
||||
// NHWC format
|
||||
long[] expectedShape = new long[]{1, 3, 3, 3};
|
||||
SDVariable requestedSize = sd.constant(Nd4j.createFromArray( new long[]{3, 3}));
|
||||
|
||||
Function<INDArray, String> checkFunction = in -> {
|
||||
boolean shapeOk = Arrays.equals(expectedShape, in.shape());
|
||||
if (shapeOk) return null;
|
||||
return "Failed: shape differs - expected " + Arrays.toString(expectedShape) + " vs " + Arrays.toString(in.shape()) + " on method " + method;
|
||||
};
|
||||
Function<INDArray, String> checkFunction = in -> {
|
||||
boolean shapeOk = Arrays.equals(expectedShape, in.shape());
|
||||
if (shapeOk) return null;
|
||||
return "Failed: shape differs - expected " + Arrays.toString(expectedShape) + " vs " + Arrays.toString(in.shape()) + " on method " + method;
|
||||
};
|
||||
|
||||
|
||||
SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable().std(true);
|
||||
SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable().std(true);
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.gradientCheck(false)
|
||||
.expected("image_resize", checkFunction));
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.gradientCheck(false)
|
||||
.expected("image_resize", checkFunction));
|
||||
|
||||
assertNull(err);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testMaximumBp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMaximumBp(Nd4jBackend backend) {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -2177,7 +2278,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMergeAddBp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMergeAddBp(Nd4jBackend backend) {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -2194,7 +2297,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMergeMaxBp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMergeMaxBp(Nd4jBackend backend) {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -2212,7 +2317,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@Test
|
||||
public void testMergeAvgBp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMergeAvgBp(Nd4jBackend backend) {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -2229,7 +2336,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReverseBp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReverseBp(Nd4jBackend backend) {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -2243,7 +2352,9 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testUpsampling3dBp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testUpsampling3dBp(Nd4jBackend backend) {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
for (boolean dataformat : new boolean[]{true, false}) {
|
||||
|
|
|
@ -24,8 +24,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
|||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
||||
|
@ -36,11 +37,8 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
|||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
public class ConvConfigTests extends BaseNd4jTest {
|
||||
public class ConvConfigTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
public ConvConfigTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -48,7 +46,9 @@ public class ConvConfigTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDeConv2D(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDeConv2D(Nd4jBackend backend){
|
||||
DeConv2DConfig.builder().kH(2).kW(4).build();
|
||||
|
||||
try{
|
||||
|
@ -108,8 +108,10 @@ public class ConvConfigTests extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConv2D(){
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConv2D(Nd4jBackend backend){
|
||||
Conv2DConfig.builder().kH(2).kW(4).build();
|
||||
|
||||
try{
|
||||
|
@ -169,8 +171,10 @@ public class ConvConfigTests extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPooling2D(){
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPooling2D(Nd4jBackend backend){
|
||||
Pooling2DConfig.builder().kH(2).kW(4).build();
|
||||
|
||||
try{
|
||||
|
@ -230,8 +234,10 @@ public class ConvConfigTests extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeConv3D(){
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDeConv3D(Nd4jBackend backend){
|
||||
DeConv3DConfig.builder().kH(2).kW(4).kD(3).build();
|
||||
|
||||
try{
|
||||
|
@ -319,8 +325,10 @@ public class ConvConfigTests extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConv3D(){
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConv3D(Nd4jBackend backend){
|
||||
Conv3DConfig.builder().kH(2).kW(4).kD(3).build();
|
||||
|
||||
try{
|
||||
|
@ -410,8 +418,10 @@ public class ConvConfigTests extends BaseNd4jTest {
|
|||
|
||||
|
||||
|
||||
@Test
|
||||
public void testPooling3D(){
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPooling3D(Nd4jBackend backend){
|
||||
Pooling3DConfig.builder().kH(2).kW(4).kD(3).build();
|
||||
|
||||
try{
|
||||
|
@ -499,7 +509,9 @@ public class ConvConfigTests extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConv1D(){
|
||||
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();
|
||||
|
||||
|
|
|
@ -23,8 +23,10 @@ package org.nd4j.autodiff.samediff;
|
|||
import lombok.val;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.OpValidationSuite;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
@ -40,11 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Disabled("AB 2019/05/21 - JVM Crash on ppc64 - Issue #7657")
|
||||
public class FailingSameDiffTests extends BaseNd4jTest {
|
||||
public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
public FailingSameDiffTests(Nd4jBackend b){
|
||||
super(b);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
|
@ -52,7 +51,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEye(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEye(Nd4jBackend backend){
|
||||
//OpValidationSuite.ignoreFailing();
|
||||
INDArray arr = Nd4j.create(new double[]{1, 0, 0, 0, 1, 0}, new int[]{2, 3});
|
||||
List<INDArray> stack = new ArrayList<>();
|
||||
|
@ -68,7 +69,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEyeShape(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEyeShape(Nd4jBackend backend){
|
||||
val dco = DynamicCustomOp.builder("eye")
|
||||
.addIntegerArguments(3,3)
|
||||
//.addIntegerArguments(-99,3,3) //Also fails
|
||||
|
@ -80,7 +83,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testExecutionDifferentShapesTransform(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testExecutionDifferentShapesTransform(Nd4jBackend backend){
|
||||
OpValidationSuite.ignoreFailing();
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.var("in", Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4));
|
||||
|
@ -101,7 +106,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDropout() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDropout(Nd4jBackend backend) {
|
||||
OpValidationSuite.ignoreFailing();
|
||||
SameDiff sd = SameDiff.create();
|
||||
double p = 0.5;
|
||||
|
@ -114,7 +121,9 @@ public class FailingSameDiffTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testExecutionDifferentShapesDynamicCustom(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testExecutionDifferentShapesDynamicCustom(Nd4jBackend backend){
|
||||
OpValidationSuite.ignoreFailing();
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
|
|
@ -26,13 +26,15 @@ import org.apache.commons.io.IOUtils;
|
|||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.graph.FlatConfiguration;
|
||||
import org.nd4j.graph.FlatGraph;
|
||||
import org.nd4j.graph.FlatNode;
|
||||
import org.nd4j.graph.FlatVariable;
|
||||
import org.nd4j.graph.IntPair;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
|
||||
|
@ -70,11 +72,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Slf4j
|
||||
public class FlatBufferSerdeTest extends BaseNd4jTest {
|
||||
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public FlatBufferSerdeTest(Nd4jBackend b){
|
||||
super(b);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
|
@ -84,7 +83,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testBasic(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
SameDiff sd = SameDiff.create();
|
||||
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
|
||||
SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() );
|
||||
|
@ -139,7 +140,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSimple(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
for( int i = 0; i < 10; i++ ) {
|
||||
for(boolean execFirst : new boolean[]{false, true}) {
|
||||
log.info("Starting test: i={}, execFirst={}", i, execFirst);
|
||||
|
@ -268,7 +271,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testTrainingSerde(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
|
||||
//Ensure 2 things:
|
||||
//1. Training config is serialized/deserialized correctly
|
||||
|
@ -352,7 +357,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void pooling3DSerialization(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void pooling3DSerialization(Nd4jBackend backend){
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28);
|
||||
|
@ -372,7 +379,9 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void pooling3DSerialization2(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void pooling3DSerialization2(Nd4jBackend backend){
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 28, 28);
|
||||
|
|
|
@ -22,12 +22,14 @@ package org.nd4j.autodiff.samediff;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
|
||||
import org.nd4j.autodiff.samediff.transform.OpPredicate;
|
||||
import org.nd4j.autodiff.samediff.transform.SubGraph;
|
||||
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
|
||||
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
||||
|
@ -42,11 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
|
|||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Slf4j
|
||||
public class GraphTransformUtilTests extends BaseNd4jTest {
|
||||
public class GraphTransformUtilTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
public GraphTransformUtilTests(Nd4jBackend b){
|
||||
super(b);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
|
@ -54,7 +53,9 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBasic(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBasic(Nd4jBackend backend){
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 32);
|
||||
|
@ -93,7 +94,9 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSubgraphReplace1(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSubgraphReplace1(Nd4jBackend backend){
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, -1, 4);
|
||||
|
|
|
@ -21,8 +21,10 @@
|
|||
package org.nd4j.autodiff.samediff;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -32,11 +34,8 @@ import java.lang.reflect.Field;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class MemoryMgrTest extends BaseNd4jTest {
|
||||
public class MemoryMgrTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public MemoryMgrTest(Nd4jBackend b){
|
||||
super(b);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
|
@ -44,7 +43,9 @@ public class MemoryMgrTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testArrayReuseTooLarge() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testArrayReuseTooLarge(Nd4jBackend backend) throws Exception {
|
||||
|
||||
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
|
||||
Field f = ArrayCacheMemoryMgr.class.getDeclaredField("maxCacheBytes");
|
||||
|
@ -97,7 +98,7 @@ public class MemoryMgrTest extends BaseNd4jTest {
|
|||
assertEquals(10, mmgr.getLruCacheValues().size());
|
||||
|
||||
//now, allocate some values:
|
||||
for( int i=1; i<=10; i++ ) {
|
||||
for( int i = 1; i <= 10; i++) {
|
||||
INDArray a1 = mmgr.allocate(true, DataType.FLOAT, 25);
|
||||
assertEquals(1000 - i * 100, mmgr.getCurrentCacheSize());
|
||||
assertEquals(1000 - i * 100, as.getBytesSum());
|
||||
|
@ -116,10 +117,12 @@ public class MemoryMgrTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testManyArrays(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testManyArrays(Nd4jBackend backend){
|
||||
|
||||
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
|
||||
for( int i=0; i<1000; i++ ){
|
||||
for( int i = 0; i < 1000; i++) {
|
||||
mmgr.release(Nd4j.scalar(0));
|
||||
}
|
||||
|
||||
|
@ -127,7 +130,7 @@ public class MemoryMgrTest extends BaseNd4jTest {
|
|||
assertEquals(1000, mmgr.getLruCache().size());
|
||||
assertEquals(1000, mmgr.getLruCacheValues().size());
|
||||
|
||||
for( int i=0; i<1000; i++ ){
|
||||
for( int i = 0; i < 1000; i++ ){
|
||||
mmgr.release(Nd4j.scalar(0));
|
||||
}
|
||||
|
||||
|
|
|
@ -21,9 +21,11 @@
|
|||
package org.nd4j.autodiff.samediff;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
|
@ -35,19 +37,18 @@ import java.util.Set;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class NameScopeTests extends BaseNd4jTest {
|
||||
public class NameScopeTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
public NameScopeTests(Nd4jBackend b){
|
||||
super(b);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVariableNameScopesBasic(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testVariableNameScopesBasic(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable v = sd.var("x");
|
||||
|
@ -73,7 +74,9 @@ public class NameScopeTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testOpFieldsAndNames(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testOpFieldsAndNames(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable x = sd.var("x", DataType.FLOAT, 1);
|
||||
|
@ -151,7 +154,9 @@ public class NameScopeTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNoNesting(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNoNesting(Nd4jBackend backend) {
|
||||
SameDiff SD = SameDiff.create();
|
||||
|
||||
SDVariable a = SD.constant(4);
|
||||
|
@ -168,7 +173,9 @@ public class NameScopeTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNoTesting2(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNoTesting2(Nd4jBackend backend) {
|
||||
SameDiff SD = SameDiff.create();
|
||||
|
||||
SDVariable a = SD.constant(4);
|
||||
|
|
|
@ -21,21 +21,16 @@
|
|||
package org.nd4j.autodiff.samediff;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.common.primitives.AtomicBoolean;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.imports.tfgraphs.TFGraphTestZooModels;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.primitives.AtomicBoolean;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Collections;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.Semaphore;
|
||||
|
@ -55,7 +50,9 @@ public class SameDiffMultiThreadTests extends BaseND4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSimple() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSimple(Nd4jBackend backend) throws Exception {
|
||||
|
||||
int nThreads = 4;
|
||||
int nRuns = 1000;
|
||||
|
@ -103,48 +100,6 @@ public class SameDiffMultiThreadTests extends BaseND4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled //2020/03/24 AB - https://github.com/eclipse/deeplearning4j/issues/8802
|
||||
public void testMobilenet(@TempDir Path testDir) throws Exception {
|
||||
TFGraphTestZooModels.currentTestDir = testDir.toFile();
|
||||
File f = Resources.asFile("tf_graphs/zoo_models/mobilenet_v2_1.0_224/tf_model.txt");
|
||||
SameDiff sd = TFGraphTestZooModels.LOADER.apply(f, "mobilenet_v2_1.0_224");
|
||||
// System.out.println(sd.summary());
|
||||
|
||||
int nThreads = 4;
|
||||
int nRuns = 30;
|
||||
INDArray[] inputArrs = new INDArray[nThreads];
|
||||
INDArray[] expOut = new INDArray[nThreads];
|
||||
for( int i=0; i<nThreads; i++ ){
|
||||
if(i == 0 || i > 2)
|
||||
inputArrs[i] = Nd4j.rand(DataType.FLOAT, 1, 224, 224, 3);
|
||||
else if(i == 1)
|
||||
inputArrs[i] = Nd4j.zeros(DataType.FLOAT, 1, 224, 224, 3);
|
||||
else if(i == 2)
|
||||
inputArrs[i] = Nd4j.ones(DataType.FLOAT, 1, 224, 224, 3);
|
||||
|
||||
expOut[i] = sd.outputSingle(Collections.singletonMap("input", inputArrs[i]), "MobilenetV2/Predictions/Reshape_1");
|
||||
Nd4j.getExecutioner().commit();
|
||||
}
|
||||
|
||||
AtomicBoolean[] failuresByThread = new AtomicBoolean[nThreads];
|
||||
AtomicInteger[] counters = new AtomicInteger[nThreads];
|
||||
Semaphore s = new Semaphore(nThreads);
|
||||
CountDownLatch latch = new CountDownLatch(nThreads);
|
||||
|
||||
doTest(sd, nThreads, nRuns, inputArrs, expOut, "input", "MobilenetV2/Predictions/Reshape_1", failuresByThread, counters, s, latch);
|
||||
|
||||
s.release(nThreads);
|
||||
latch.await();
|
||||
|
||||
for(int i = 0; i < nThreads; i++) {
|
||||
assertFalse( failuresByThread[i].get(),"Thread " + i + " failed");
|
||||
}
|
||||
|
||||
for(int i = 0; i < nThreads; i++) {
|
||||
assertEquals( nRuns, counters[i].get(),"Thread " + i + " number of runs");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public static void doTest(SameDiff sd, int nThreads, int nRuns, INDArray[] inputArrs, INDArray[] expOut,
|
||||
|
|
|
@ -21,7 +21,9 @@
|
|||
package org.nd4j.autodiff.samediff;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
|
@ -31,14 +33,13 @@ import org.nd4j.linalg.learning.config.Sgd;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class SameDiffOutputTest extends BaseNd4jTest {
|
||||
public class SameDiffOutputTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public SameDiffOutputTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void outputTest(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void outputTest(Nd4jBackend backend){
|
||||
DataSet data = new DataSet(Nd4j.zeros(10, 10), Nd4j.zeros(10, 10));
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
|
|
|
@ -21,7 +21,9 @@
|
|||
package org.nd4j.autodiff.samediff;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
|
@ -35,19 +37,18 @@ import static junit.framework.TestCase.assertNotNull;
|
|||
import static junit.framework.TestCase.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
|
||||
public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
public SameDiffSpecifiedLossVarsTests(Nd4jBackend b){
|
||||
super(b);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSpecifiedLoss1(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSpecifiedLoss1(Nd4jBackend backend) {
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable ph1 = sd.var("ph", DataType.FLOAT, 3, 4);
|
||||
ph1.setArray(Nd4j.create(DataType.FLOAT, 3, 4));
|
||||
|
@ -68,7 +69,9 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSpecifiedLoss2(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSpecifiedLoss2(Nd4jBackend backend) {
|
||||
for( int i=0; i<2; i++ ) {
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable ph = sd.placeHolder("ph", DataType.FLOAT, 3, 4);
|
||||
|
@ -121,7 +124,9 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testTrainingDifferentLosses(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTrainingDifferentLosses(Nd4jBackend backend) {
|
||||
//Net with 2 losses: train on the first one, then change losses
|
||||
//Also check that if modifying via add/setLossVariables the training config changes
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -30,11 +30,13 @@ import java.util.Map;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.listeners.impl.ScoreListener;
|
||||
import org.nd4j.autodiff.listeners.records.History;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
|
@ -55,14 +57,13 @@ import org.nd4j.linalg.learning.config.Sgd;
|
|||
import org.nd4j.weightinit.impl.XavierInitScheme;
|
||||
|
||||
@Slf4j
|
||||
public class SameDiffTrainingTest extends BaseNd4jTest {
|
||||
public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public SameDiffTrainingTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void irisTrainingSanityCheck() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void irisTrainingSanityCheck(Nd4jBackend backend) {
|
||||
|
||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
NormalizerStandardize std = new NormalizerStandardize();
|
||||
|
@ -134,7 +135,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void irisTrainingEvalTest() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void irisTrainingEvalTest(Nd4jBackend backend) {
|
||||
|
||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
NormalizerStandardize std = new NormalizerStandardize();
|
||||
|
@ -184,7 +187,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void irisTrainingValidationTest() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void irisTrainingValidationTest(Nd4jBackend backend) {
|
||||
|
||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
NormalizerStandardize std = new NormalizerStandardize();
|
||||
|
@ -239,6 +244,8 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTrainingMixedDtypes(){
|
||||
|
||||
for (String u : new String[]{"adam", "nesterov", "adamax", "amsgrad"}) {
|
||||
|
@ -301,7 +308,9 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void simpleClassification() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void simpleClassification(Nd4jBackend backend) {
|
||||
double learning_rate = 0.001;
|
||||
int seed = 7;
|
||||
org.nd4j.linalg.api.rng.Random rng = Nd4j.getRandom();
|
||||
|
@ -348,6 +357,8 @@ public class SameDiffTrainingTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTrainingEvalVarNotReqForLoss(){
|
||||
//If a variable is not required for the loss - normally it won't be calculated
|
||||
//But we want to make sure it IS calculated here - so we can perform evaluation on it
|
||||
|
|
|
@ -25,11 +25,13 @@ import org.junit.Assert;
|
|||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.TrainingConfig;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.dataset.IrisDataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
|
@ -48,11 +50,8 @@ import java.util.concurrent.TimeUnit;
|
|||
import static junit.framework.TestCase.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class CheckpointListenerTest extends BaseNd4jTest {
|
||||
public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public CheckpointListenerTest(Nd4jBackend backend){
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
|
@ -96,7 +95,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testCheckpointEveryEpoch(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
File dir = testDir.toFile();
|
||||
|
||||
SameDiff sd = getModel();
|
||||
|
@ -130,7 +131,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testCheckpointEvery5Iter(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
File dir = testDir.toFile();
|
||||
|
||||
SameDiff sd = getModel();
|
||||
|
@ -169,7 +172,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
File dir = testDir.toFile();
|
||||
SameDiff sd = getModel();
|
||||
|
||||
|
@ -199,7 +204,7 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
|||
for(File f : files){
|
||||
String s = f.getAbsolutePath();
|
||||
// System.out.println(s);
|
||||
for( int i=0; i<names.size(); i++ ){
|
||||
for( int i = 0; i < names.size(); i++ ){
|
||||
if(s.contains(names.get(i))){
|
||||
found[i] = true;
|
||||
break;
|
||||
|
@ -213,7 +218,9 @@ public class CheckpointListenerTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
File dir = testDir.toFile();
|
||||
SameDiff sd = getModel();
|
||||
|
||||
|
|
|
@ -21,12 +21,13 @@
|
|||
package org.nd4j.autodiff.samediff.listeners;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener;
|
||||
import org.nd4j.autodiff.loss.LossReduce;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.TrainingConfig;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
|
@ -34,14 +35,13 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.learning.config.Adam;
|
||||
|
||||
public class ExecDebuggingListenerTest extends BaseNd4jTest {
|
||||
public class ExecDebuggingListenerTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public ExecDebuggingListenerTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testExecDebugListener(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testExecDebugListener(Nd4jBackend backend) {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
package org.nd4j.autodiff.samediff.listeners;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.BaseListener;
|
||||
import org.nd4j.autodiff.listeners.Listener;
|
||||
|
@ -38,7 +40,7 @@ import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
|||
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.Evaluation.Metric;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.OpContext;
|
||||
|
@ -61,11 +63,8 @@ import java.util.Map;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class ListenerTest extends BaseNd4jTest {
|
||||
public class ListenerTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public ListenerTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -73,7 +72,9 @@ public class ListenerTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void irisHistoryTest() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void irisHistoryTest(Nd4jBackend backend) {
|
||||
|
||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
NormalizerStandardize std = new NormalizerStandardize();
|
||||
|
@ -136,6 +137,8 @@ public class ListenerTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testListenerCalls(){
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
|
||||
|
@ -273,7 +276,9 @@ public class ListenerTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testCustomListener() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCustomListener(Nd4jBackend backend) {
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
|
||||
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
|
||||
|
|
|
@ -26,11 +26,13 @@ import org.apache.commons.lang3.StringUtils;
|
|||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.listeners.profiler.ProfilingListener;
|
||||
import org.nd4j.autodiff.listeners.profiler.comparison.ProfileAnalyzer;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -45,11 +47,8 @@ import java.util.Map;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
|
||||
public class ProfilingListenerTest extends BaseNd4jTest {
|
||||
public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public ProfilingListenerTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -59,7 +58,9 @@ public class ProfilingListenerTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testProfilingListenerSimple(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
|
||||
|
@ -107,19 +108,25 @@ public class ProfilingListenerTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
/*
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLoadTfProfile(){
|
||||
File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json");
|
||||
ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLoadTfProfileDir(){
|
||||
File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles");
|
||||
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLoadTfProfileDir2(){
|
||||
File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0");
|
||||
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
|
||||
|
|
|
@ -27,6 +27,8 @@ import org.junit.jupiter.api.BeforeEach;
|
|||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
|
@ -41,7 +43,7 @@ import org.nd4j.graph.UIInfoType;
|
|||
import org.nd4j.graph.UIOp;
|
||||
import org.nd4j.graph.UIVariable;
|
||||
import org.nd4j.graph.ui.LogFileWriter;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -60,11 +62,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Slf4j
|
||||
public class FileReadWriteTests extends BaseNd4jTest {
|
||||
public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
public FileReadWriteTests(Nd4jBackend b){
|
||||
super(b);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
|
@ -81,7 +80,9 @@ public class FileReadWriteTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSimple(@TempDir Path testDir) throws IOException {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
|
||||
SDVariable sum = v.sum();
|
||||
|
@ -185,7 +186,9 @@ public class FileReadWriteTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNullBinLabels(@TempDir Path testDir) throws Exception{
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{
|
||||
File dir = testDir.toFile();
|
||||
File f = new File(dir, "temp.bin");
|
||||
LogFileWriter w = new LogFileWriter(f);
|
||||
|
|
|
@ -25,6 +25,8 @@ import com.google.flatbuffers.Table;
|
|||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.listeners.impl.UIListener;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -34,7 +36,7 @@ import org.nd4j.graph.UIEvent;
|
|||
import org.nd4j.graph.UIGraphStructure;
|
||||
import org.nd4j.graph.UIStaticInfoRecord;
|
||||
import org.nd4j.graph.ui.LogFileWriter;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.IrisDataSetIterator;
|
||||
|
@ -51,11 +53,8 @@ import java.util.Map;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class UIListenerTest extends BaseNd4jTest {
|
||||
public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public UIListenerTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -65,7 +64,9 @@ public class UIListenerTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testUIListenerBasic(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
|
@ -101,7 +102,9 @@ public class UIListenerTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testUIListenerContinue(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
|
||||
SameDiff sd1 = getSimpleNet();
|
||||
|
@ -192,7 +195,9 @@ public class UIListenerTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testUIListenerBadContinue(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
SameDiff sd1 = getSimpleNet();
|
||||
|
||||
|
|
|
@ -23,18 +23,17 @@ package org.nd4j.evaluation;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.evaluation.custom.CustomEvaluation;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
|
||||
public class CustomEvaluationTest extends BaseNd4jTest {
|
||||
public class CustomEvaluationTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public CustomEvaluationTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -42,8 +41,10 @@ public class CustomEvaluationTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void customEvalTest(){
|
||||
CustomEvaluation accuracyEval = new CustomEvaluation<Pair<Number, Long>>(
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void customEvalTest(Nd4jBackend backend){
|
||||
CustomEvaluation accuracyEval = new CustomEvaluation<>(
|
||||
(labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)),
|
||||
CustomEvaluation.mergeConcatenate());
|
||||
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
package org.nd4j.evaluation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.EvaluationBinary;
|
||||
import org.nd4j.evaluation.classification.EvaluationCalibration;
|
||||
|
@ -29,25 +31,24 @@ import org.nd4j.evaluation.classification.ROCBinary;
|
|||
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
public class EmptyEvaluationTests extends BaseNd4jTest {
|
||||
public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
public EmptyEvaluationTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyEvaluation() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEmptyEvaluation (Nd4jBackend backend) {
|
||||
Evaluation e = new Evaluation();
|
||||
System.out.println(e.stats());
|
||||
|
||||
|
@ -62,7 +63,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyRegressionEvaluation() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEmptyRegressionEvaluation (Nd4jBackend backend) {
|
||||
RegressionEvaluation re = new RegressionEvaluation();
|
||||
re.stats();
|
||||
|
||||
|
@ -76,7 +79,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyEvaluationBinary() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEmptyEvaluationBinary(Nd4jBackend backend) {
|
||||
EvaluationBinary eb = new EvaluationBinary();
|
||||
eb.stats();
|
||||
|
||||
|
@ -91,7 +96,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyROC() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEmptyROC(Nd4jBackend backend) {
|
||||
ROC roc = new ROC();
|
||||
roc.stats();
|
||||
|
||||
|
@ -106,7 +113,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyROCBinary() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEmptyROCBinary(Nd4jBackend backend) {
|
||||
ROCBinary rb = new ROCBinary();
|
||||
rb.stats();
|
||||
|
||||
|
@ -121,7 +130,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyROCMultiClass() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEmptyROCMultiClass(Nd4jBackend backend) {
|
||||
ROCMultiClass r = new ROCMultiClass();
|
||||
r.stats();
|
||||
|
||||
|
@ -136,7 +147,9 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyEvaluationCalibration() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEmptyEvaluationCalibration(Nd4jBackend backend) {
|
||||
EvaluationCalibration ec = new EvaluationCalibration();
|
||||
ec.stats();
|
||||
|
||||
|
|
|
@ -21,9 +21,11 @@
|
|||
package org.nd4j.evaluation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.EvaluationBinary;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
|
||||
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||
|
@ -36,11 +38,8 @@ import java.util.Random;
|
|||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class EvalCustomThreshold extends BaseNd4jTest {
|
||||
public class EvalCustomThreshold extends BaseNd4jTestWithBackends {
|
||||
|
||||
public EvalCustomThreshold(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -48,7 +47,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluationCustomBinaryThreshold() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationCustomBinaryThreshold(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
//Sanity checks: 0.5 threshold for 1-output and 2-output binary cases
|
||||
|
@ -114,7 +115,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluationCostArray() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationCostArray(Nd4jBackend backend) {
|
||||
|
||||
|
||||
int nExamples = 20;
|
||||
|
@ -162,7 +165,9 @@ public class EvalCustomThreshold extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluationBinaryCustomThreshold() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationBinaryCustomThreshold(Nd4jBackend backend) {
|
||||
|
||||
//Sanity check: same results for 0.5 threshold vs. default (no threshold)
|
||||
int nExamples = 20;
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
package org.nd4j.evaluation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.EvaluationBinary;
|
||||
import org.nd4j.evaluation.classification.EvaluationCalibration;
|
||||
|
@ -31,7 +33,7 @@ import org.nd4j.evaluation.curves.Histogram;
|
|||
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
|
||||
import org.nd4j.evaluation.curves.RocCurve;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -42,11 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
|
||||
public class EvalJsonTest extends BaseNd4jTest {
|
||||
public class EvalJsonTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public EvalJsonTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -54,7 +53,9 @@ public class EvalJsonTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSerdeEmpty() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerdeEmpty(Nd4jBackend backend) {
|
||||
boolean print = false;
|
||||
|
||||
IEvaluation[] arr = new IEvaluation[] {new Evaluation(), new EvaluationBinary(), new ROCBinary(10),
|
||||
|
@ -73,8 +74,10 @@ public class EvalJsonTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSerde() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerde(Nd4jBackend backend) {
|
||||
boolean print = false;
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -121,8 +124,10 @@ public class EvalJsonTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSerdeExactRoc() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerdeExactRoc(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
boolean print = false;
|
||||
|
||||
|
@ -199,8 +204,10 @@ public class EvalJsonTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJsonYamlCurves() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testJsonYamlCurves(Nd4jBackend backend) {
|
||||
ROC roc = new ROC(0);
|
||||
|
||||
INDArray evalLabel =
|
||||
|
@ -251,8 +258,10 @@ public class EvalJsonTest extends BaseNd4jTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJsonWithCustomThreshold() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testJsonWithCustomThreshold(Nd4jBackend backend) {
|
||||
|
||||
//Evaluation - binary threshold
|
||||
Evaluation e = new Evaluation(0.25);
|
||||
|
|
|
@ -21,8 +21,10 @@
|
|||
package org.nd4j.evaluation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
@ -39,11 +41,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
|
||||
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
||||
|
||||
public class EvalTest extends BaseNd4jTest {
|
||||
public class EvalTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public EvalTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -52,7 +51,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testEval() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEval(Nd4jBackend backend) {
|
||||
int classNum = 5;
|
||||
Evaluation eval = new Evaluation (classNum);
|
||||
|
||||
|
@ -91,7 +92,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEval2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEval2(Nd4jBackend backend) {
|
||||
|
||||
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
||||
Evaluation first = null;
|
||||
|
@ -150,7 +153,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStringListLabels() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStringListLabels(Nd4jBackend backend) {
|
||||
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
|
||||
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
|
||||
|
||||
|
@ -167,7 +172,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testStringHashLabels() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testStringHashLabels(Nd4jBackend backend) {
|
||||
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
|
||||
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
|
||||
|
||||
|
@ -184,7 +191,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvalMasking() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvalMasking(Nd4jBackend backend) {
|
||||
int miniBatch = 5;
|
||||
int nOut = 3;
|
||||
int tsLength = 6;
|
||||
|
@ -251,7 +260,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testFalsePerfectRecall() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testFalsePerfectRecall(Nd4jBackend backend) {
|
||||
int testSize = 100;
|
||||
int numClasses = 5;
|
||||
int winner = 1;
|
||||
|
@ -284,7 +295,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluationMerging() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationMerging(Nd4jBackend backend) {
|
||||
|
||||
int nRows = 20;
|
||||
int nCols = 3;
|
||||
|
@ -358,7 +371,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testSingleClassBinaryClassification() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSingleClassBinaryClassification(Nd4jBackend backend) {
|
||||
|
||||
Evaluation eval = new Evaluation(1);
|
||||
|
||||
|
@ -387,7 +402,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvalInvalid() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvalInvalid(Nd4jBackend backend) {
|
||||
Evaluation e = new Evaluation(5);
|
||||
e.eval(0, 1);
|
||||
e.eval(1, 0);
|
||||
|
@ -400,7 +417,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvalMethods() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvalMethods(Nd4jBackend backend) {
|
||||
//Check eval(int,int) vs. eval(INDArray,INDArray)
|
||||
|
||||
Evaluation e1 = new Evaluation(4);
|
||||
|
@ -443,7 +462,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testTopNAccuracy() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTopNAccuracy(Nd4jBackend backend) {
|
||||
|
||||
Evaluation e = new Evaluation(null, 3);
|
||||
|
||||
|
@ -504,7 +525,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testTopNAccuracyMerging() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTopNAccuracyMerging(Nd4jBackend backend) {
|
||||
|
||||
Evaluation e1 = new Evaluation(null, 3);
|
||||
Evaluation e2 = new Evaluation(null, 3);
|
||||
|
@ -552,7 +575,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBinaryCase() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBinaryCase(Nd4jBackend backend) {
|
||||
INDArray ones10 = Nd4j.ones(10, 1);
|
||||
INDArray ones4 = Nd4j.ones(4, 1);
|
||||
INDArray zeros4 = Nd4j.zeros(4, 1);
|
||||
|
@ -581,7 +606,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testF1FBeta_MicroMacroAveraging() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testF1FBeta_MicroMacroAveraging(Nd4jBackend backend) {
|
||||
//Confusion matrix: rows = actual, columns = predicted
|
||||
//[3, 1, 0]
|
||||
//[2, 2, 1]
|
||||
|
@ -722,7 +749,9 @@ public class EvalTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testConfusionMatrixStats() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConfusionMatrixStats(Nd4jBackend backend) {
|
||||
|
||||
Evaluation e = new Evaluation();
|
||||
|
||||
|
@ -743,6 +772,8 @@ public class EvalTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvalBinaryMetrics(){
|
||||
|
||||
Evaluation ePosClass1_nOut2 = new Evaluation(2, 1);
|
||||
|
@ -864,6 +895,8 @@ public class EvalTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConfusionMatrixString(){
|
||||
|
||||
Evaluation e = new Evaluation(Arrays.asList("a","b","c"));
|
||||
|
@ -914,6 +947,8 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationNaNs(){
|
||||
|
||||
Evaluation e = new Evaluation();
|
||||
|
@ -929,6 +964,8 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSegmentation(){
|
||||
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -1023,6 +1060,8 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLabelReset(){
|
||||
|
||||
Map<Integer,String> m = new HashMap<>();
|
||||
|
@ -1056,6 +1095,8 @@ public class EvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvalStatsBinaryCase(){
|
||||
//Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case
|
||||
|
||||
|
|
|
@ -21,9 +21,11 @@
|
|||
package org.nd4j.evaluation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.EvaluationBinary;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -38,11 +40,8 @@ import java.util.List;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.nd4j.evaluation.classification.EvaluationBinary.Metric.*;
|
||||
public class EvaluationBinaryTest extends BaseNd4jTest {
|
||||
public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public EvaluationBinaryTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -50,7 +49,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluationBinary() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationBinary(Nd4jBackend backend) {
|
||||
//Compare EvaluationBinary to Evaluation class
|
||||
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
||||
EvaluationBinary first = null;
|
||||
|
@ -136,7 +137,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluationBinaryMerging() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationBinaryMerging(Nd4jBackend backend) {
|
||||
int nOut = 4;
|
||||
int[] shape1 = {30, nOut};
|
||||
int[] shape2 = {50, nOut};
|
||||
|
@ -163,7 +166,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluationBinaryPerOutputMasking() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationBinaryPerOutputMasking(Nd4jBackend backend) {
|
||||
|
||||
//Provide a mask array: "ignore" the masked steps
|
||||
|
||||
|
@ -172,7 +177,7 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
|
|||
INDArray labels = Nd4j.create(new double[][] {{1, 1, 1}, {0, 0, 0}, {1, 1, 1}, {0, 1, 1}, {1, 0, 1}});
|
||||
|
||||
INDArray predicted = Nd4j.create(new double[][] {{0.9, 0.9, 0.9}, {0.7, 0.7, 0.7}, {0.6, 0.6, 0.6},
|
||||
{0.4, 0.4, 0.4}, {0.1, 0.1, 0.1}});
|
||||
{0.4, 0.4, 0.4}, {0.1, 0.1, 0.1}});
|
||||
|
||||
//Correct?
|
||||
// Y Y m
|
||||
|
@ -206,7 +211,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testTimeSeriesEval() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTimeSeriesEval(Nd4jBackend backend) {
|
||||
|
||||
int[] shape = {2, 4, 3};
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -230,12 +237,14 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluationBinaryWithROC() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationBinaryWithROC(Nd4jBackend backend) {
|
||||
//Simple test for nested ROCBinary in EvaluationBinary
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray l1 = Nd4j.getExecutioner()
|
||||
.exec(new BernoulliDistribution(Nd4j.createUninitialized(new int[] {50, 4}), 0.5));
|
||||
.exec(new BernoulliDistribution(Nd4j.createUninitialized(new int[] {50, 4}), 0.5));
|
||||
INDArray p1 = Nd4j.rand(50, 4);
|
||||
|
||||
EvaluationBinary eb = new EvaluationBinary(4, 30);
|
||||
|
@ -247,7 +256,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testEvaluationBinary3d() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationBinary3d(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||
|
||||
|
@ -281,7 +292,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluationBinary4d() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationBinary4d(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
|
||||
|
@ -315,7 +328,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluationBinary3dMasking() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationBinary3dMasking(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||
|
||||
|
@ -376,7 +391,9 @@ public class EvaluationBinaryTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluationBinary4dMasking() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationBinary4dMasking(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
|
||||
|
|
|
@ -21,8 +21,10 @@
|
|||
package org.nd4j.evaluation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.evaluation.classification.EvaluationCalibration;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -39,19 +41,18 @@ import java.util.Random;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class EvaluationCalibrationTest extends BaseNd4jTest {
|
||||
public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public EvaluationCalibrationTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
public char ordering () {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReliabilityDiagram() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReliabilityDiagram (Nd4jBackend backend) {
|
||||
|
||||
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
||||
EvaluationCalibration first = null;
|
||||
|
@ -142,8 +143,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLabelAndPredictionCounts() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLabelAndPredictionCounts (Nd4jBackend backend) {
|
||||
|
||||
int minibatch = 50;
|
||||
int nClasses = 3;
|
||||
|
@ -170,8 +173,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
|
|||
assertArrayEquals(expPredictionCount, ec.getPredictionCountsEachClass());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResidualPlots() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testResidualPlots (Nd4jBackend backend) {
|
||||
|
||||
int minibatch = 50;
|
||||
int nClasses = 3;
|
||||
|
@ -271,7 +276,9 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSegmentation(){
|
||||
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -365,8 +372,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluationCalibration3d() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationCalibration3d (Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||
|
||||
|
@ -397,8 +406,10 @@ public class EvaluationCalibrationTest extends BaseNd4jTest {
|
|||
assertEquals(e2d.stats(), e3d.stats());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEvaluationCalibration3dMasking() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvaluationCalibration3dMasking (Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||
|
||||
|
|
|
@ -23,6 +23,8 @@ package org.nd4j.evaluation;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.EvaluationBinary;
|
||||
import org.nd4j.evaluation.classification.EvaluationCalibration;
|
||||
|
@ -30,17 +32,14 @@ import org.nd4j.evaluation.classification.ROC;
|
|||
import org.nd4j.evaluation.classification.ROCBinary;
|
||||
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
public class NewInstanceTest extends BaseNd4jTest {
|
||||
public class NewInstanceTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public NewInstanceTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -48,7 +47,9 @@ public class NewInstanceTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNewInstances() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNewInstances(Nd4jBackend backend) {
|
||||
boolean print = true;
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
|
|
@ -21,10 +21,12 @@
|
|||
package org.nd4j.evaluation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.evaluation.classification.ROC;
|
||||
import org.nd4j.evaluation.classification.ROCBinary;
|
||||
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -39,19 +41,17 @@ import java.util.List;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class ROCBinaryTest extends BaseNd4jTest {
|
||||
|
||||
public ROCBinaryTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testROCBinary() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testROCBinary(Nd4jBackend backend) {
|
||||
//Compare ROCBinary to ROC class
|
||||
|
||||
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
||||
|
@ -145,8 +145,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRocBinaryMerging() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRocBinaryMerging(Nd4jBackend backend) {
|
||||
for (int nSteps : new int[]{30, 0}) { //0 == exact
|
||||
int nOut = 4;
|
||||
int[] shape1 = {30, nOut};
|
||||
|
@ -175,8 +177,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testROCBinaryPerOutputMasking() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testROCBinaryPerOutputMasking(Nd4jBackend backend) {
|
||||
|
||||
for (int nSteps : new int[]{30, 0}) { //0 == exact
|
||||
|
||||
|
@ -215,8 +219,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
|
||||
@Test
|
||||
public void testROCBinary3d() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testROCBinary3d(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||
|
||||
|
@ -249,8 +255,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testROCBinary4d() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testROCBinary4d(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
|
||||
|
@ -283,8 +291,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testROCBinary3dMasking() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testROCBinary3dMasking(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||
|
||||
|
@ -344,8 +354,10 @@ public class ROCBinaryTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testROCBinary4dMasking() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testROCBinary4dMasking(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
|
||||
|
|
|
@ -21,12 +21,14 @@
|
|||
package org.nd4j.evaluation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.evaluation.classification.ROC;
|
||||
import org.nd4j.evaluation.classification.ROCBinary;
|
||||
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
|
||||
import org.nd4j.evaluation.curves.RocCurve;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
@ -39,11 +41,8 @@ import java.util.*;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class ROCTest extends BaseNd4jTest {
|
||||
public class ROCTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public ROCTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -83,8 +82,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
expFPR.put(10 / 10.0, 0.0 / totalNegatives);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRocBasic() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRocBasic(Nd4jBackend backend) {
|
||||
//2 outputs here - probability distribution over classes (softmax)
|
||||
INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
||||
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
||||
|
@ -126,8 +127,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
assertEquals(1.0, auc, 1e-6);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRocBasicSingleClass() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRocBasicSingleClass(Nd4jBackend backend) {
|
||||
//1 output here - single probability value (sigmoid)
|
||||
|
||||
//add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
||||
|
@ -164,8 +167,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testRoc() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRoc(Nd4jBackend backend) {
|
||||
//Previous tests allowed for a perfect classifier with right threshold...
|
||||
|
||||
INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}});
|
||||
|
@ -249,8 +254,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testRocTimeSeriesNoMasking() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
|
||||
//Same as first test...
|
||||
|
||||
//2 outputs here - probability distribution over classes (softmax)
|
||||
|
@ -296,8 +303,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRocTimeSeriesMasking() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRocTimeSeriesMasking(Nd4jBackend backend) {
|
||||
//2 outputs here - probability distribution over classes (softmax)
|
||||
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
||||
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
||||
|
@ -346,8 +355,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
|
||||
@Test
|
||||
public void testCompareRocAndRocMultiClass() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
//For 2 class case: ROC and Multi-class ROC should be the same...
|
||||
|
@ -376,8 +387,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCompare2Vs3Classes() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCompare2Vs3Classes(Nd4jBackend backend) {
|
||||
|
||||
//ROC multi-class: 2 vs. 3 classes should be the same, if we add two of the classes together...
|
||||
//Both methods implement one vs. all ROC/AUC in different ways
|
||||
|
@ -425,8 +438,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testROCMerging() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testROCMerging(Nd4jBackend backend) {
|
||||
int nArrays = 10;
|
||||
int minibatch = 64;
|
||||
int nROCs = 3;
|
||||
|
@ -470,8 +485,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testROCMerging2() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testROCMerging2(Nd4jBackend backend) {
|
||||
int nArrays = 10;
|
||||
int minibatch = 64;
|
||||
int exactAllocBlockSize = 10;
|
||||
|
@ -515,8 +532,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testROCMultiMerging() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testROCMultiMerging(Nd4jBackend backend) {
|
||||
|
||||
int nArrays = 10;
|
||||
int minibatch = 64;
|
||||
|
@ -563,8 +582,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAUCPrecisionRecall() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAUCPrecisionRecall(Nd4jBackend backend) {
|
||||
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
|
||||
//at threshold 0 to 0.24999: tp=2, fp=1, fn=0, tn=0 prec=2/(2+1)=0.666, recall=2/2=1.0
|
||||
//at threshold 0.25 to 0.33: tp=2, fp=0, fn=0, tn=1 prec=2/2=1, recall=2/2=1
|
||||
|
@ -610,8 +631,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testRocAucExact() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRocAucExact(Nd4jBackend backend) {
|
||||
|
||||
//Check the implementation vs. Scikitlearn
|
||||
/*
|
||||
|
@ -773,8 +796,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void rocExactEdgeCaseReallocation() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
|
||||
|
||||
//Set reallocation block size to say 20, but then evaluate a 100-length array
|
||||
|
||||
|
@ -785,8 +810,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testPrecisionRecallCurveGetPointMethods() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
|
||||
double[] threshold = new double[101];
|
||||
double[] precision = threshold;
|
||||
double[] recall = new double[101];
|
||||
|
@ -821,8 +848,10 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPrecisionRecallCurveConfusion() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
|
||||
//Sanity check: values calculated from the confusion matrix should match the PR curve values
|
||||
|
||||
for (boolean removeRedundantPts : new boolean[] {true, false}) {
|
||||
|
@ -860,7 +889,9 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRocMerge(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -904,7 +935,9 @@ public class ROCTest extends BaseNd4jTest {
|
|||
assertEquals(auprc, auprcAct, 1e-6);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRocMultiMerge(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -953,7 +986,9 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRocBinaryMerge(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -998,7 +1033,9 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSegmentationBinary(){
|
||||
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -1088,7 +1125,9 @@ public class ROCTest extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSegmentation(){
|
||||
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
|
|
@ -21,9 +21,11 @@
|
|||
package org.nd4j.evaluation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -40,11 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
|
|||
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
|
||||
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
||||
|
||||
public class RegressionEvalTest extends BaseNd4jTest {
|
||||
public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
public RegressionEvalTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -52,7 +51,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test()
|
||||
public void testEvalParameters() {
|
||||
public void testEvalParameters(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
int specCols = 5;
|
||||
INDArray labels = Nd4j.ones(3);
|
||||
|
@ -65,7 +64,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testPerfectPredictions() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPerfectPredictions(Nd4jBackend backend) {
|
||||
|
||||
int nCols = 5;
|
||||
int nTestArrays = 100;
|
||||
|
@ -92,7 +93,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testKnownValues() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testKnownValues(Nd4jBackend backend) {
|
||||
|
||||
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
||||
RegressionEvaluation first = null;
|
||||
|
@ -148,7 +151,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testRegressionEvaluationMerging() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRegressionEvaluationMerging(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
int nRows = 20;
|
||||
|
@ -189,7 +194,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRegressionEvalPerOutputMasking() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRegressionEvalPerOutputMasking(Nd4jBackend backend) {
|
||||
|
||||
INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}});
|
||||
|
||||
|
@ -216,6 +223,8 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRegressionEvalTimeSeriesSplit(){
|
||||
|
||||
INDArray out1 = Nd4j.rand(new int[]{3, 5, 20});
|
||||
|
@ -238,7 +247,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRegressionEval3d() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRegressionEval3d(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||
|
||||
|
@ -270,7 +281,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRegressionEval4d() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRegressionEval4d(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
|
||||
|
@ -302,7 +315,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRegressionEval3dMasking() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRegressionEval3dMasking(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||
|
||||
|
@ -361,7 +376,9 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRegressionEval4dMasking() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRegressionEval4dMasking(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
|
||||
|
|
|
@ -22,10 +22,12 @@ package org.nd4j.evaluation;
|
|||
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
|
@ -34,11 +36,8 @@ import java.nio.charset.StandardCharsets;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class TestLegacyJsonLoading extends BaseNd4jTest {
|
||||
public class TestLegacyJsonLoading extends BaseNd4jTestWithBackends {
|
||||
|
||||
public TestLegacyJsonLoading(Nd4jBackend b){
|
||||
super(b);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
|
@ -46,7 +45,9 @@ public class TestLegacyJsonLoading extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEvalLegacyFormat() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEvalLegacyFormat(Nd4jBackend backend) throws Exception {
|
||||
|
||||
File f = new ClassPathResource("regression_testing/eval_100b/evaluation.json").getFile();
|
||||
String s = FileUtils.readFileToString(f, StandardCharsets.UTF_8);
|
||||
|
|
|
@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -38,17 +39,14 @@ import java.util.List;
|
|||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
public class AveragingTests extends BaseNd4jTest {
|
||||
|
||||
public class AveragingTests extends BaseNd4jTestWithBackends {
|
||||
private final int THREADS = 16;
|
||||
private final int LENGTH = 51200 * 4;
|
||||
|
||||
DataType initialType;
|
||||
DataType initialType = Nd4j.dataType();
|
||||
|
||||
|
||||
public AveragingTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
this.initialType = Nd4j.dataType();
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
public void setUp() {
|
||||
|
@ -63,7 +61,9 @@ public class AveragingTests extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testSingleDeviceAveraging1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSingleDeviceAveraging1(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0);
|
||||
INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0);
|
||||
INDArray array3 = Nd4j.valueArrayOf(LENGTH, 3.0);
|
||||
|
@ -110,7 +110,9 @@ public class AveragingTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSingleDeviceAveraging2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSingleDeviceAveraging2(Nd4jBackend backend) {
|
||||
INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH);
|
||||
List<INDArray> arrays = new ArrayList<>();
|
||||
for (int i = 0; i < THREADS; i++)
|
||||
|
@ -127,7 +129,9 @@ public class AveragingTests extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testAccumulation1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAccumulation1(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.create(100).assign(1.0);
|
||||
INDArray array2 = Nd4j.create(100).assign(2.0);
|
||||
INDArray array3 = Nd4j.create(100).assign(3.0);
|
||||
|
@ -140,7 +144,9 @@ public class AveragingTests extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testAccumulation2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAccumulation2(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.create(100).assign(1.0);
|
||||
INDArray array2 = Nd4j.create(100).assign(2.0);
|
||||
INDArray array3 = Nd4j.create(100).assign(3.0);
|
||||
|
@ -155,7 +161,9 @@ public class AveragingTests extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testAccumulation3() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAccumulation3(Nd4jBackend backend) {
|
||||
// we want to ensure that cuda backend is able to launch this op on cpu
|
||||
Nd4j.getAffinityManager().allowCrossDeviceAccess(false);
|
||||
|
||||
|
|
|
@ -23,8 +23,9 @@ package org.nd4j.linalg;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -34,15 +35,14 @@ import java.io.*;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
|
||||
@Slf4j
|
||||
public class DataTypeTest extends BaseNd4jTest {
|
||||
public DataTypeTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class DataTypeTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Test
|
||||
public void testDataTypes() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDataTypes(Nd4jBackend backend) throws Exception {
|
||||
for (val type : DataType.values()) {
|
||||
if (DataType.UTF8.equals(type) || DataType.UNKNOWN.equals(type) || DataType.COMPRESSED.equals(type))
|
||||
continue;
|
||||
|
|
|
@ -21,20 +21,17 @@
|
|||
package org.nd4j.linalg;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class InputValidationTests extends BaseNd4jTest {
|
||||
|
||||
public InputValidationTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class InputValidationTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -45,7 +42,9 @@ public class InputValidationTests extends BaseNd4jTest {
|
|||
///////////////////// Broadcast Tests ///////////////////////
|
||||
|
||||
@Test
|
||||
public void testInvalidColVectorOp1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testInvalidColVectorOp1(Nd4jBackend backend) {
|
||||
INDArray first = Nd4j.create(10, 10);
|
||||
INDArray col = Nd4j.create(5, 1);
|
||||
try {
|
||||
|
@ -57,7 +56,9 @@ public class InputValidationTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidColVectorOp2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testInvalidColVectorOp2(Nd4jBackend backend) {
|
||||
INDArray first = Nd4j.create(10, 10);
|
||||
INDArray col = Nd4j.create(5, 1);
|
||||
try {
|
||||
|
@ -69,7 +70,9 @@ public class InputValidationTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidRowVectorOp1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testInvalidRowVectorOp1(Nd4jBackend backend) {
|
||||
INDArray first = Nd4j.create(10, 10);
|
||||
INDArray row = Nd4j.create(1, 5);
|
||||
try {
|
||||
|
@ -81,7 +84,9 @@ public class InputValidationTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidRowVectorOp2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testInvalidRowVectorOp2(Nd4jBackend backend) {
|
||||
INDArray first = Nd4j.create(10, 10);
|
||||
INDArray row = Nd4j.create(1, 5);
|
||||
try {
|
||||
|
|
|
@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import lombok.val;
|
||||
import org.apache.commons.lang3.RandomUtils;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
||||
|
@ -47,14 +48,13 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
|
||||
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
public class LoneTest extends BaseNd4jTest {
|
||||
public LoneTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
public class LoneTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Test
|
||||
public void testSoftmaxStability() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSoftmaxStability(Nd4jBackend backend) {
|
||||
INDArray input = Nd4j.create(new double[]{-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose();
|
||||
// System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo()));
|
||||
INDArray output = Nd4j.create(DataType.DOUBLE, 10, 1);
|
||||
|
@ -68,7 +68,9 @@ public class LoneTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testFlattenedView() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testFlattenedView(Nd4jBackend backend) {
|
||||
int rows = 8;
|
||||
int cols = 8;
|
||||
int dim2 = 4;
|
||||
|
@ -104,7 +106,9 @@ public class LoneTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testIndexingColVec() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIndexingColVec(Nd4jBackend backend) {
|
||||
int elements = 5;
|
||||
INDArray rowVector = Nd4j.linspace(1, elements, elements).reshape(1, elements);
|
||||
INDArray colVector = rowVector.transpose();
|
||||
|
@ -123,7 +127,9 @@ public class LoneTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void concatScalarVectorIssue() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void concatScalarVectorIssue(Nd4jBackend backend) {
|
||||
//A bug was found when the first array that concat sees is a scalar and the rest vectors + scalars
|
||||
INDArray arr1 = Nd4j.create(1, 1);
|
||||
INDArray arr2 = Nd4j.create(1, 8);
|
||||
|
@ -133,7 +139,9 @@ public class LoneTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void reshapeTensorMmul() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void reshapeTensorMmul(Nd4jBackend backend) {
|
||||
INDArray a = Nd4j.linspace(1, 2, 12).reshape(2, 3, 2);
|
||||
INDArray b = Nd4j.linspace(3, 4, 4).reshape(2, 2);
|
||||
int[][] axes = new int[2][];
|
||||
|
@ -145,7 +153,9 @@ public class LoneTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void maskWhenMerge() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void maskWhenMerge(Nd4jBackend backend) {
|
||||
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
|
||||
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
|
||||
List<DataSet> dataSetList = new ArrayList<DataSet>();
|
||||
|
@ -160,7 +170,9 @@ public class LoneTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRelu() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRelu(Nd4jBackend backend) {
|
||||
INDArray aA = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
|
||||
INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
|
||||
INDArray b = Nd4j.getExecutioner().exec(new Tanh(aA));
|
||||
|
@ -172,7 +184,7 @@ public class LoneTest extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
//broken at a threshold
|
||||
public void testArgMax() {
|
||||
public void testArgMax(Nd4jBackend backend) {
|
||||
int max = 63;
|
||||
INDArray A = Nd4j.linspace(1, max, max).reshape(1, max);
|
||||
int currentArgMax = Nd4j.argMax(A).getInt(0);
|
||||
|
@ -186,7 +198,9 @@ public class LoneTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRPF() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRPF(Nd4jBackend backend) {
|
||||
val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3);
|
||||
|
||||
log.info("--------");
|
||||
|
@ -199,7 +213,9 @@ public class LoneTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testConcat3D_Vstack_C() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConcat3D_Vstack_C(Nd4jBackend backend) {
|
||||
val shape = new long[]{1, 1000, 20};
|
||||
|
||||
List<INDArray> cArrays = new ArrayList<>();
|
||||
|
@ -229,7 +245,9 @@ public class LoneTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testGetRow1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetRow1(Nd4jBackend backend) {
|
||||
INDArray array = Nd4j.create(10000, 10000);
|
||||
|
||||
//Thread.sleep(10000);
|
||||
|
@ -256,7 +274,7 @@ public class LoneTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test()
|
||||
public void checkIllegalElementOps() {
|
||||
public void checkIllegalElementOps(Nd4jBackend backend) {
|
||||
assertThrows(Exception.class,() -> {
|
||||
INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5);
|
||||
INDArray B = A.dup().reshape(2, 2, 5);
|
||||
|
@ -268,7 +286,9 @@ public class LoneTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void checkSliceofSlice() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void checkSliceofSlice(Nd4jBackend backend) {
|
||||
/*
|
||||
Issue 1: Slice of slice with c order and f order views are not equal
|
||||
|
||||
|
@ -308,7 +328,9 @@ public class LoneTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void checkWithReshape() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void checkWithReshape(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.create(1, 3);
|
||||
INDArray reshaped = arr.reshape('f', 3, 1);
|
||||
for (int i=0;i<reshaped.length();i++) {
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
package org.nd4j.linalg;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -28,11 +30,8 @@ import org.nd4j.linalg.factory.Nd4jBackend;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class MmulBug extends BaseNd4jTest {
|
||||
public class MmulBug extends BaseNd4jTestWithBackends {
|
||||
|
||||
public MmulBug(Nd4jBackend b){
|
||||
super(b);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
|
@ -40,7 +39,9 @@ public class MmulBug extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void simpleTest() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void simpleTest(Nd4jBackend backend) {
|
||||
INDArray m1 = Nd4j.create(new double[][] {{1.0}, {2.0}, {3.0}, {4.0}});
|
||||
|
||||
m1 = m1.reshape(2, 2);
|
||||
|
|
|
@ -25,8 +25,9 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import lombok.val;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -58,19 +59,14 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
|
||||
@Slf4j
|
||||
public class NDArrayTestsFortran extends BaseNd4jTest {
|
||||
|
||||
|
||||
public NDArrayTestsFortran(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
|
||||
public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Test
|
||||
public void testScalarOps() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testScalarOps(Nd4jBackend backend) {
|
||||
INDArray n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
|
||||
assertEquals(27d, n.length(), 1e-1);
|
||||
n.addi(Nd4j.scalar(1d));
|
||||
|
@ -88,7 +84,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testColumnMmul() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testColumnMmul(Nd4jBackend backend) {
|
||||
DataBuffer data = Nd4j.linspace(1, 10, 18, DataType.FLOAT).data();
|
||||
INDArray x2 = Nd4j.create(data, new long[] {2, 3, 3});
|
||||
data = Nd4j.linspace(1, 12, 9, DataType.FLOAT).data();
|
||||
|
@ -119,7 +117,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testRowVectorGemm() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRowVectorGemm(Nd4jBackend backend) {
|
||||
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1).castTo(DataType.DOUBLE);
|
||||
INDArray other = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4).castTo(DataType.DOUBLE);
|
||||
INDArray result = linspace.mmul(other);
|
||||
|
@ -130,13 +130,17 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testRepmat() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRepmat(Nd4jBackend backend) {
|
||||
INDArray rowVector = Nd4j.create(1, 4);
|
||||
INDArray repmat = rowVector.repmat(4, 4);
|
||||
assertTrue(Arrays.equals(new long[] {4, 16}, repmat.shape()));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReadWrite() throws Exception {
|
||||
INDArray write = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||
|
@ -152,6 +156,8 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReadWriteDouble() throws Exception {
|
||||
INDArray write = Nd4j.linspace(1, 4, 4, DataType.FLOAT);
|
||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||
|
@ -168,18 +174,17 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMultiThreading() throws Exception {
|
||||
ExecutorService ex = ExecutorServiceProvider.getExecutorService();
|
||||
|
||||
List<Future<?>> list = new ArrayList<>(100);
|
||||
for (int i = 0; i < 100; i++) {
|
||||
Future<?> future = ex.submit(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
INDArray dot = Nd4j.linspace(1, 8, 8, DataType.DOUBLE);
|
||||
Future<?> future = ex.submit(() -> {
|
||||
INDArray dot = Nd4j.linspace(1, 8, 8, DataType.DOUBLE);
|
||||
// System.out.println(Transforms.sigmoid(dot));
|
||||
Transforms.sigmoid(dot);
|
||||
}
|
||||
Transforms.sigmoid(dot);
|
||||
});
|
||||
list.add(future);
|
||||
}
|
||||
|
@ -191,7 +196,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testBroadcastingGenerated() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBroadcastingGenerated(Nd4jBackend backend) {
|
||||
int[][] broadcastShape = NDArrayCreationUtil.getRandomBroadCastShape(7, 6, 10);
|
||||
List<List<Pair<INDArray, String>>> broadCastList = new ArrayList<>(broadcastShape.length);
|
||||
for (int[] shape : broadcastShape) {
|
||||
|
@ -206,7 +213,7 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
INDArray inputArrBroadcast = val.getFirst();
|
||||
val destShape = NDArrayCreationUtil.broadcastToShape(inputArrBroadcast.shape(), 7);
|
||||
INDArray output = inputArrBroadcast
|
||||
.broadcast(NDArrayCreationUtil.broadcastToShape(inputArrBroadcast.shape(), 7));
|
||||
.broadcast(NDArrayCreationUtil.broadcastToShape(inputArrBroadcast.shape(), 7));
|
||||
assertArrayEquals(destShape, output.shape());
|
||||
}
|
||||
}
|
||||
|
@ -216,7 +223,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testBroadCasting() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBroadCasting(Nd4jBackend backend) {
|
||||
INDArray first = Nd4j.arange(0, 3).reshape(3, 1).castTo(DataType.DOUBLE);
|
||||
INDArray ret = first.broadcast(3, 4);
|
||||
INDArray testRet = Nd4j.create(new double[][] {{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}});
|
||||
|
@ -229,14 +238,18 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testOneTensor() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testOneTensor(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.ones(1, 1, 1, 1, 1, 1, 1);
|
||||
INDArray matrixToBroadcast = Nd4j.ones(1, 1);
|
||||
assertEquals(matrixToBroadcast.broadcast(arr.shape()), arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSortWithIndicesDescending() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSortWithIndicesDescending(Nd4jBackend backend) {
|
||||
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
//indices,data
|
||||
INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, false);
|
||||
|
@ -247,7 +260,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSortDeadlock() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSortDeadlock(Nd4jBackend backend) {
|
||||
val toSort = Nd4j.linspace(DataType.DOUBLE, 1, 32*768, 1).reshape(32, 768);
|
||||
|
||||
val sorted = Nd4j.sort(toSort.dup(), 1, false);
|
||||
|
@ -255,7 +270,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testSortWithIndices() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSortWithIndices(Nd4jBackend backend) {
|
||||
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
//indices,data
|
||||
INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, true);
|
||||
|
@ -266,14 +283,18 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNd4jSortScalar() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNd4jSortScalar(Nd4jBackend backend) {
|
||||
INDArray linspace = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(1, -1);
|
||||
INDArray sorted = Nd4j.sort(linspace, 1, false);
|
||||
// System.out.println(sorted);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSwapAxesFortranOrder() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSwapAxesFortranOrder(Nd4jBackend backend) {
|
||||
INDArray n = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}).castTo(DataType.DOUBLE);
|
||||
for (int i = 0; i < n.slices(); i++) {
|
||||
INDArray nSlice = n.slice(i);
|
||||
|
@ -292,7 +313,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testDimShuffle() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDimShuffle(Nd4jBackend backend) {
|
||||
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
INDArray twoOneTwo = n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false});
|
||||
assertTrue(Arrays.equals(new long[] {2, 1, 2}, twoOneTwo.shape()));
|
||||
|
@ -303,7 +326,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGetVsGetScalar() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetVsGetScalar(Nd4jBackend backend) {
|
||||
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
float element = a.getFloat(0, 1);
|
||||
double element2 = a.getDouble(0, 1);
|
||||
|
@ -316,7 +341,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDivide() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDivide(Nd4jBackend backend) {
|
||||
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
|
||||
INDArray div = two.div(two);
|
||||
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage());
|
||||
|
@ -330,7 +357,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testSigmoid() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSigmoid(Nd4jBackend backend) {
|
||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
|
||||
INDArray sigmoid = Transforms.sigmoid(n, false);
|
||||
|
@ -339,7 +368,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNeg() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNeg(Nd4jBackend backend) {
|
||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
||||
INDArray neg = Transforms.neg(n);
|
||||
|
@ -349,7 +380,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testCosineSim() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCosineSim(Nd4jBackend backend) {
|
||||
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||
double sim = Transforms.cosineSim(vec1, vec2);
|
||||
|
@ -364,7 +397,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testExp() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testExp(Nd4jBackend backend) {
|
||||
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||
INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f});
|
||||
INDArray exped = Transforms.exp(n);
|
||||
|
@ -374,7 +409,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testScalar() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testScalar(Nd4jBackend backend) {
|
||||
INDArray a = Nd4j.scalar(1.0f);
|
||||
assertEquals(true, a.isScalar());
|
||||
|
||||
|
@ -386,7 +423,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testWrap() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testWrap(Nd4jBackend backend) {
|
||||
int[] shape = {2, 4};
|
||||
INDArray d = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(shape[0], shape[1]);
|
||||
INDArray n = d;
|
||||
|
@ -411,7 +450,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGetRowFortran() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetRowFortran(Nd4jBackend backend) {
|
||||
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.FLOAT).data(), new long[] {2, 2});
|
||||
INDArray column = Nd4j.create(new float[] {1, 3});
|
||||
INDArray column2 = Nd4j.create(new float[] {2, 4});
|
||||
|
@ -424,7 +465,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGetColumnFortran() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetColumnFortran(Nd4jBackend backend) {
|
||||
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2});
|
||||
INDArray column = Nd4j.create(new double[] {1, 2});
|
||||
INDArray column2 = Nd4j.create(new double[] {3, 4});
|
||||
|
@ -438,7 +481,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testGetColumns() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetColumns(Nd4jBackend backend) {
|
||||
INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3).castTo(DataType.DOUBLE);
|
||||
// log.info("Original: {}", matrix);
|
||||
INDArray matrixGet = matrix.getColumns(1, 2);
|
||||
|
@ -452,7 +497,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testVectorInit() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testVectorInit(Nd4jBackend backend) {
|
||||
DataBuffer data = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data();
|
||||
INDArray arr = Nd4j.create(data, new long[] {1, 4});
|
||||
assertEquals(true, arr.isRowVector());
|
||||
|
@ -465,7 +512,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testAssignOffset() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAssignOffset(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.ones(5, 5);
|
||||
INDArray row = arr.slice(1);
|
||||
row.assign(1);
|
||||
|
@ -473,7 +522,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testColumns() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testColumns(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE);
|
||||
INDArray column = Nd4j.create(new double[] {1, 2, 3});
|
||||
arr.putColumn(0, column);
|
||||
|
@ -511,7 +562,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testPutRow() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPutRow(Nd4jBackend backend) {
|
||||
INDArray d = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
INDArray n = d.dup();
|
||||
|
||||
|
@ -570,7 +623,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testInplaceTranspose() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testInplaceTranspose(Nd4jBackend backend) {
|
||||
INDArray test = Nd4j.rand(3, 4);
|
||||
INDArray orig = test.dup();
|
||||
INDArray transposei = test.transposei();
|
||||
|
@ -585,7 +640,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testMmulF() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMmulF(Nd4jBackend backend) {
|
||||
|
||||
DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data();
|
||||
INDArray n = Nd4j.create(data, new long[] {1, 10});
|
||||
|
@ -603,7 +660,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testRowsColumns() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRowsColumns(Nd4jBackend backend) {
|
||||
DataBuffer data = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).data();
|
||||
INDArray rows = Nd4j.create(data, new long[] {2, 3});
|
||||
assertEquals(2, rows.rows());
|
||||
|
@ -619,7 +678,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testTranspose() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTranspose(Nd4jBackend backend) {
|
||||
INDArray n = Nd4j.create(Nd4j.ones(100).castTo(DataType.DOUBLE).data(), new long[] {5, 5, 4});
|
||||
INDArray transpose = n.transpose();
|
||||
assertEquals(n.length(), transpose.length());
|
||||
|
@ -647,7 +708,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testAddMatrix() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAddMatrix(Nd4jBackend backend) {
|
||||
INDArray five = Nd4j.ones(5);
|
||||
five.addi(five.dup());
|
||||
INDArray twos = Nd4j.valueArrayOf(5, 2);
|
||||
|
@ -658,7 +721,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testMMul() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMMul(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
|
||||
|
||||
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
||||
|
@ -669,7 +734,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testPutSlice() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPutSlice(Nd4jBackend backend) {
|
||||
INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3);
|
||||
INDArray newSlice = Nd4j.create(DataType.DOUBLE, 3, 3);
|
||||
Nd4j.exec(new PrintVariable(newSlice));
|
||||
|
@ -680,7 +747,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRowVectorMultipleIndices() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
|
||||
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
|
||||
linear.putScalar(new long[] {0, 1}, 1);
|
||||
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage());
|
||||
|
@ -689,7 +758,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testDim1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDim1(Nd4jBackend backend) {
|
||||
INDArray sum = Nd4j.linspace(1, 2, 2, DataType.DOUBLE).reshape(2, 1);
|
||||
INDArray same = sum.dup();
|
||||
assertEquals(same.sum(1), sum.reshape(2));
|
||||
|
@ -697,7 +768,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testEps() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEps(Nd4jBackend backend) {
|
||||
val ones = Nd4j.ones(5);
|
||||
val res = Nd4j.createUninitialized(DataType.BOOL, 5);
|
||||
assertTrue(Nd4j.getExecutioner().exec(new Eps(ones, ones, res)).all());
|
||||
|
@ -705,7 +778,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testLogDouble() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testLogDouble(Nd4jBackend backend) {
|
||||
INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).castTo(DataType.DOUBLE);
|
||||
INDArray log = Transforms.log(linspace);
|
||||
INDArray assertion = Nd4j.create(new double[] {0, 0.6931471805599453, 1.0986122886681098, 1.3862943611198906, 1.6094379124341005, 1.791759469228055});
|
||||
|
@ -713,28 +788,36 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testVectorSum() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testVectorSum(Nd4jBackend backend) {
|
||||
INDArray lin = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorSum2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testVectorSum2(Nd4jBackend backend) {
|
||||
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorSum3() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testVectorSum3(Nd4jBackend backend) {
|
||||
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||
INDArray lin2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||
assertEquals(lin, lin2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSmallSum() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSmallSum(Nd4jBackend backend) {
|
||||
INDArray base = Nd4j.create(new double[] {5.843333333333335, 3.0540000000000007});
|
||||
base.addi(1e-12);
|
||||
INDArray assertion = Nd4j.create(new double[] {5.84333433, 3.054001});
|
||||
|
@ -745,7 +828,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testPermute() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPermute(Nd4jBackend backend) {
|
||||
INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4});
|
||||
INDArray transpose = n.transpose();
|
||||
INDArray permute = n.permute(1, 0);
|
||||
|
@ -774,7 +859,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testAppendBias() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAppendBias(Nd4jBackend backend) {
|
||||
INDArray rand = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(1, -1).transpose();
|
||||
INDArray test = Nd4j.appendBias(rand);
|
||||
INDArray assertion = Nd4j.toFlattened(rand, Nd4j.scalar(DataType.DOUBLE, 1.0)).reshape(-1, 1);
|
||||
|
@ -782,7 +869,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRand() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRand(Nd4jBackend backend) {
|
||||
INDArray rand = Nd4j.randn(5, 5);
|
||||
Nd4j.getDistributions().createUniform(0.4, 4).sample(5);
|
||||
Nd4j.getDistributions().createNormal(1, 5).sample(10);
|
||||
|
@ -794,7 +883,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testIdentity() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIdentity(Nd4jBackend backend) {
|
||||
INDArray eye = Nd4j.eye(5);
|
||||
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
|
||||
eye = Nd4j.eye(5);
|
||||
|
@ -805,7 +896,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testColumnVectorOpsFortran() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testColumnVectorOpsFortran(Nd4jBackend backend) {
|
||||
INDArray twoByTwo = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
|
||||
INDArray toAdd = Nd4j.create(new float[] {1, 2}, new long[] {2, 1});
|
||||
twoByTwo.addiColumnVector(toAdd);
|
||||
|
@ -816,7 +909,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testRSubi() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRSubi(Nd4jBackend backend) {
|
||||
INDArray n2 = Nd4j.ones(2);
|
||||
INDArray n2Assertion = Nd4j.zeros(2);
|
||||
INDArray nRsubi = n2.rsubi(1);
|
||||
|
@ -826,7 +921,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testAssign() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAssign(Nd4jBackend backend) {
|
||||
INDArray vector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
|
||||
vector.assign(1);
|
||||
assertEquals(Nd4j.ones(5).castTo(DataType.DOUBLE), vector);
|
||||
|
@ -843,7 +940,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAddScalar() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAddScalar(Nd4jBackend backend) {
|
||||
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
|
||||
INDArray rdiv = div.add(1);
|
||||
INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 5.0);
|
||||
|
@ -851,7 +950,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRdivScalar() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRdivScalar(Nd4jBackend backend) {
|
||||
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
|
||||
INDArray rdiv = div.rdiv(1);
|
||||
INDArray answer = Nd4j.valueArrayOf(new long[] {1, 4}, 0.25);
|
||||
|
@ -859,7 +960,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRDivi() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRDivi(Nd4jBackend backend) {
|
||||
INDArray n2 = Nd4j.valueArrayOf(new long[] {1, 2}, 4.0);
|
||||
INDArray n2Assertion = Nd4j.valueArrayOf(new long[] {1, 2}, 0.5);
|
||||
INDArray nRsubi = n2.rdivi(2);
|
||||
|
@ -869,7 +972,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testNumVectorsAlongDimension() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNumVectorsAlongDimension(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2);
|
||||
assertEquals(12, arr.vectorsAlongDimension(2));
|
||||
}
|
||||
|
@ -877,7 +982,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testBroadCast() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBroadCast(Nd4jBackend backend) {
|
||||
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||
INDArray broadCasted = n.broadcast(5, 4);
|
||||
for (int i = 0; i < broadCasted.rows(); i++) {
|
||||
|
@ -899,7 +1006,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMatrix() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMatrix(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
||||
INDArray brr = Nd4j.create(new double[] {5, 6}, new long[] {2});
|
||||
INDArray row = arr.getRow(0);
|
||||
|
@ -909,7 +1018,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testPutRowGetRowOrdering() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPutRowGetRowOrdering(Nd4jBackend backend) {
|
||||
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
INDArray put = Nd4j.create(new double[] {5, 6});
|
||||
row1.putRow(1, put);
|
||||
|
@ -931,7 +1042,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testSumWithRow1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSumWithRow1(Nd4jBackend backend) {
|
||||
//Works:
|
||||
INDArray array2d = Nd4j.ones(1, 10);
|
||||
array2d.sum(0); //OK
|
||||
|
@ -962,7 +1075,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSumWithRow2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSumWithRow2(Nd4jBackend backend) {
|
||||
//All sums in this method execute without exceptions.
|
||||
INDArray array3d = Nd4j.ones(2, 10, 10);
|
||||
array3d.sum(0);
|
||||
|
@ -985,7 +1100,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testPutRowFortran() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPutRowFortran(Nd4jBackend backend) {
|
||||
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2).castTo(DataType.DOUBLE);
|
||||
INDArray put = Nd4j.create(new double[] {5, 6});
|
||||
row1.putRow(1, put);
|
||||
|
@ -998,7 +1115,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testElementWiseOps() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testElementWiseOps(Nd4jBackend backend) {
|
||||
INDArray n1 = Nd4j.scalar(1);
|
||||
INDArray n2 = Nd4j.scalar(2);
|
||||
INDArray nClone = n1.add(n2);
|
||||
|
@ -1021,7 +1140,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testRollAxis() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRollAxis(Nd4jBackend backend) {
|
||||
INDArray toRoll = Nd4j.ones(3, 4, 5, 6);
|
||||
assertArrayEquals(new long[] {3, 6, 4, 5}, Nd4j.rollAxis(toRoll, 3, 1).shape());
|
||||
val shape = Nd4j.rollAxis(toRoll, 3).shape();
|
||||
|
@ -1030,20 +1151,22 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testTensorDot() {
|
||||
public void testTensorDot(Nd4jBackend backend) {
|
||||
INDArray oneThroughSixty = Nd4j.arange(60).reshape('f', 3, 4, 5).castTo(DataType.DOUBLE);
|
||||
INDArray oneThroughTwentyFour = Nd4j.arange(24).reshape('f', 4, 3, 2).castTo(DataType.DOUBLE);
|
||||
INDArray result = Nd4j.tensorMmul(oneThroughSixty, oneThroughTwentyFour, new int[][] {{1, 0}, {0, 1}});
|
||||
assertArrayEquals(new long[] {5, 2}, result.shape());
|
||||
INDArray assertion = Nd4j.create(new double[][] {{440., 1232.}, {1232., 3752.}, {2024., 6272.}, {2816., 8792.},
|
||||
{3608., 11312.}});
|
||||
{3608., 11312.}});
|
||||
assertEquals(assertion, result);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testNegativeShape() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNegativeShape(Nd4jBackend backend) {
|
||||
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||
INDArray reshaped = linspace.reshape(-1, 2);
|
||||
assertArrayEquals(new long[] {2, 2}, reshaped.shape());
|
||||
|
@ -1055,7 +1178,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGetColumnGetRow() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetColumnGetRow(Nd4jBackend backend) {
|
||||
INDArray row = Nd4j.ones(1, 5);
|
||||
for (int i = 0; i < 5; i++) {
|
||||
INDArray col = row.getColumn(i);
|
||||
|
@ -1070,7 +1195,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDupAndDupWithOrder() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDupAndDupWithOrder(Nd4jBackend backend) {
|
||||
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
|
||||
int count = 0;
|
||||
for (Pair<INDArray, String> pair : testInputs) {
|
||||
|
@ -1092,7 +1219,9 @@ public class NDArrayTestsFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testToOffsetZeroCopy() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testToOffsetZeroCopy(Nd4jBackend backend) {
|
||||
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
|
||||
|
||||
int cnt = 0;
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -23,8 +23,9 @@ package org.nd4j.linalg;
|
|||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -42,18 +43,14 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
|||
|
||||
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class Nd4jTestsComparisonC extends BaseNd4jTest {
|
||||
|
||||
public class Nd4jTestsComparisonC extends BaseNd4jTestWithBackends {
|
||||
private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonC.class);
|
||||
|
||||
public static final int SEED = 123;
|
||||
|
||||
DataType initialType;
|
||||
DataType initialType = Nd4j.dataType();
|
||||
|
||||
public Nd4jTestsComparisonC(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
this.initialType = Nd4j.dataType();
|
||||
}
|
||||
|
||||
|
||||
@BeforeEach
|
||||
|
@ -73,7 +70,9 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testGemmWithOpsCommonsMath() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
|
||||
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
|
||||
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
|
||||
|
@ -140,13 +139,13 @@ public class Nd4jTestsComparisonC extends BaseNd4jTest {
|
|||
|
||||
|
||||
private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair<INDArray, String> first,
|
||||
Pair<INDArray, String> second) {
|
||||
Pair<INDArray, String> second) {
|
||||
return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")";
|
||||
}
|
||||
|
||||
private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha,
|
||||
double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
|
||||
double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
|
||||
return i + "," + j + " - gemm(tA=" + transposeA + ",tB=" + transposeB + ",alpha=" + alpha + ",beta=" + beta
|
||||
+ "). A=" + first.getSecond() + ", B=" + second.getSecond();
|
||||
+ "). A=" + first.getSecond() + ", B=" + second.getSecond();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,8 +25,9 @@ import org.apache.commons.math3.linear.RealMatrix;
|
|||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -43,18 +44,14 @@ import java.util.Random;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
||||
|
||||
public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
|
||||
private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonFortran.class);
|
||||
|
||||
public static final int SEED = 123;
|
||||
|
||||
DataType initialType;
|
||||
DataType initialType = Nd4j.dataType();
|
||||
|
||||
public Nd4jTestsComparisonFortran(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
this.initialType = Nd4j.dataType();
|
||||
}
|
||||
|
||||
|
||||
@BeforeEach
|
||||
|
@ -75,7 +72,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testCrash() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCrash(Nd4jBackend backend) {
|
||||
INDArray array3d = Nd4j.ones(1, 10, 10);
|
||||
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 0);
|
||||
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 1);
|
||||
|
@ -85,7 +84,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMmulWithOpsCommonsMath() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMmulWithOpsCommonsMath(Nd4jBackend backend) {
|
||||
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
|
||||
|
||||
|
@ -100,7 +101,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGemmWithOpsCommonsMath() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
|
||||
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
|
||||
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
|
||||
|
@ -156,7 +159,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGemvApacheCommons() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemvApacheCommons(Nd4jBackend backend) {
|
||||
|
||||
int[] rowsArr = new int[] {4, 4, 4, 8, 8, 8};
|
||||
int[] colsArr = new int[] {2, 1, 10, 2, 1, 10};
|
||||
|
@ -197,7 +202,7 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
|||
|
||||
assertArrayEquals(new long[] {rows, 1}, gemv.shape());
|
||||
assertArrayEquals(new int[] {rows, 1},
|
||||
new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()});
|
||||
new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()});
|
||||
|
||||
//Check entries:
|
||||
for (int r = 0; r < rows; r++) {
|
||||
|
@ -211,7 +216,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAddSubtractWithOpsCommonsMath() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAddSubtractWithOpsCommonsMath(Nd4jBackend backend) {
|
||||
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||
for (int i = 0; i < first.size(); i++) {
|
||||
|
@ -229,7 +236,9 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMulDivOnCheckUtilMatrices() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMulDivOnCheckUtilMatrices(Nd4jBackend backend) {
|
||||
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||
for (int i = 0; i < first.size(); i++) {
|
||||
|
@ -245,13 +254,13 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair<INDArray, String> first,
|
||||
Pair<INDArray, String> second) {
|
||||
Pair<INDArray, String> second) {
|
||||
return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")";
|
||||
}
|
||||
|
||||
private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha,
|
||||
double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
|
||||
double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
|
||||
return i + "," + j + " - gemm(tA=" + transposeA + ",tB= " + transposeB + ",alpha=" + alpha + ",beta= " + beta
|
||||
+ "). A=" + first.getSecond() + ", B=" + second.getSecond();
|
||||
+ "). A=" + first.getSecond() + ", B=" + second.getSecond();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,8 +23,9 @@ package org.nd4j.linalg;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -36,18 +37,15 @@ import java.util.List;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
public class Nd4jTestsF extends BaseNd4jTest {
|
||||
|
||||
DataType initialType;
|
||||
public class Nd4jTestsF extends BaseNd4jTestWithBackends {
|
||||
|
||||
public Nd4jTestsF(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
this.initialType = Nd4j.dataType();
|
||||
}
|
||||
DataType initialType = Nd4j.dataType();
|
||||
|
||||
@Test
|
||||
public void testConcat3D_Vstack_F() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConcat3D_Vstack_F(Nd4jBackend backend) {
|
||||
//Nd4j.getExecutioner().enableVerboseMode(true);
|
||||
//Nd4j.getExecutioner().enableDebugMode(true);
|
||||
|
||||
|
@ -79,7 +77,9 @@ public class Nd4jTestsF extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testSlice_1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSlice_1(Nd4jBackend backend) {
|
||||
val arr = Nd4j.linspace(1,4, 4, DataType.DOUBLE).reshape(2, 2, 1);
|
||||
val exp0 = Nd4j.create(new double[]{1, 3}, new int[] {2, 1});
|
||||
val exp1 = Nd4j.create(new double[]{2, 4}, new int[] {2, 1});
|
||||
|
|
|
@ -22,8 +22,9 @@ package org.nd4j.linalg;
|
|||
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
@ -34,15 +35,13 @@ import java.util.*;
|
|||
import static junit.framework.TestCase.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class ShufflesTests extends BaseNd4jTest {
|
||||
|
||||
public ShufflesTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class ShufflesTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Test
|
||||
public void testSimpleShuffle1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSimpleShuffle1(Nd4jBackend backend) {
|
||||
INDArray array = Nd4j.zeros(10, 10);
|
||||
for (int x = 0; x < 10; x++) {
|
||||
array.getRow(x).assign(x);
|
||||
|
@ -64,7 +63,9 @@ public class ShufflesTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSimpleShuffle2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSimpleShuffle2(Nd4jBackend backend) {
|
||||
INDArray array = Nd4j.zeros(10, 10);
|
||||
for (int x = 0; x < 10; x++) {
|
||||
array.getColumn(x).assign(x);
|
||||
|
@ -79,7 +80,9 @@ public class ShufflesTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSimpleShuffle3() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSimpleShuffle3(Nd4jBackend backend) {
|
||||
INDArray array = Nd4j.zeros(11, 10);
|
||||
for (int x = 0; x < 11; x++) {
|
||||
array.getRow(x).assign(x);
|
||||
|
@ -95,7 +98,9 @@ public class ShufflesTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSymmetricShuffle1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSymmetricShuffle1(Nd4jBackend backend) {
|
||||
INDArray features = Nd4j.zeros(10, 10);
|
||||
INDArray labels = Nd4j.zeros(10, 3);
|
||||
for (int x = 0; x < 10; x++) {
|
||||
|
@ -133,7 +138,9 @@ public class ShufflesTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSymmetricShuffle2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSymmetricShuffle2(Nd4jBackend backend) {
|
||||
INDArray features = Nd4j.zeros(10, 10, 20);
|
||||
INDArray labels = Nd4j.zeros(10, 10, 3);
|
||||
|
||||
|
@ -171,7 +178,9 @@ public class ShufflesTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSymmetricShuffle3() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSymmetricShuffle3(Nd4jBackend backend) {
|
||||
INDArray features = Nd4j.zeros(10, 10, 20);
|
||||
INDArray featuresMask = Nd4j.zeros(10, 20);
|
||||
INDArray labels = Nd4j.zeros(10, 10, 3);
|
||||
|
@ -236,7 +245,9 @@ public class ShufflesTests extends BaseNd4jTest {
|
|||
* @throws Exception
|
||||
*/
|
||||
@Test
|
||||
public void testHalfVectors1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testHalfVectors1(Nd4jBackend backend) {
|
||||
int[] array1 = ArrayUtil.buildHalfVector(new Random(12), 20);
|
||||
int[] array2 = ArrayUtil.buildHalfVector(new Random(75), 20);
|
||||
|
||||
|
@ -257,7 +268,9 @@ public class ShufflesTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testInterleavedVector1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testInterleavedVector1(Nd4jBackend backend) {
|
||||
int[] array1 = ArrayUtil.buildInterleavedVector(new Random(12), 20);
|
||||
int[] array2 = ArrayUtil.buildInterleavedVector(new Random(75), 20);
|
||||
|
||||
|
@ -278,7 +291,9 @@ public class ShufflesTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testInterleavedVector3() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testInterleavedVector3(Nd4jBackend backend) {
|
||||
for (int e = 0; e < 1000; e++) {
|
||||
int length = e + 256; //RandomUtils.nextInt(121, 2073);
|
||||
int[] array1 = ArrayUtil.buildInterleavedVector(new Random(System.currentTimeMillis()), length);
|
||||
|
|
|
@ -24,8 +24,9 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.eigen.Eigen;
|
||||
|
@ -35,16 +36,11 @@ import org.nd4j.common.util.ArrayUtil;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
|
||||
@Slf4j
|
||||
public class TestEigen extends BaseNd4jTest {
|
||||
public class TestEigen extends BaseNd4jTestWithBackends {
|
||||
|
||||
protected DataType initialType;
|
||||
|
||||
public TestEigen(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
initialType = Nd4j.dataType();
|
||||
}
|
||||
protected DataType initialType = Nd4j.dataType();
|
||||
|
||||
@BeforeEach
|
||||
public void before() {
|
||||
|
@ -59,7 +55,9 @@ public class TestEigen extends BaseNd4jTest {
|
|||
// test of functions added by Luke Czapla
|
||||
// Compares solution of A x = L x to solution to A x = L B x when it is simple
|
||||
@Test
|
||||
public void test2Syev() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void test2Syev(Nd4jBackend backend) {
|
||||
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
||||
Nd4j.setDefaultDataTypes(dt, dt);
|
||||
|
||||
|
@ -78,7 +76,9 @@ public class TestEigen extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSyev() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSyev(Nd4jBackend backend) {
|
||||
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
||||
//log.info("Datatype: {}", dt);
|
||||
Nd4j.setDefaultDataTypes(dt, dt);
|
||||
|
|
|
@ -24,23 +24,23 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.common.util.ArrayUtil;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
|
||||
@Slf4j
|
||||
public class ToStringTest extends BaseNd4jTest {
|
||||
public ToStringTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class ToStringTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Test
|
||||
public void testToString() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testToString(Nd4jBackend backend) throws Exception {
|
||||
assertEquals("[ 1, 2, 3]",
|
||||
Nd4j.createFromArray(1, 2, 3).toString());
|
||||
|
||||
|
@ -58,6 +58,8 @@ public class ToStringTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testToStringScalars(){
|
||||
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
|
||||
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};
|
||||
|
|
|
@ -22,9 +22,10 @@ package org.nd4j.linalg.activations;
|
|||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.activations.impl.ActivationCube;
|
||||
import org.nd4j.linalg.activations.impl.ActivationELU;
|
||||
import org.nd4j.linalg.activations.impl.ActivationGELU;
|
||||
|
@ -55,12 +56,9 @@ import java.util.List;
|
|||
import static junit.framework.TestCase.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestActivation extends BaseNd4jTest {
|
||||
|
||||
public TestActivation(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class TestActivation extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -79,7 +77,9 @@ public class TestActivation extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRelu(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRelu(Nd4jBackend backend){
|
||||
|
||||
Double[] max = {null, 6.0, 2.5, 5.0};
|
||||
Double[] threshold = {0.0, 0.0, 0.75, 0.2};
|
||||
|
@ -131,30 +131,32 @@ public class TestActivation extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testJson() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testJson(Nd4jBackend backend) throws Exception {
|
||||
|
||||
IActivation[] activations = new IActivation[] {new ActivationCube(), new ActivationELU(0.25),
|
||||
new ActivationHardSigmoid(), new ActivationHardTanH(), new ActivationIdentity(),
|
||||
new ActivationLReLU(0.25), new ActivationRationalTanh(), new ActivationReLU(),
|
||||
new ActivationRReLU(0.25, 0.5), new ActivationSigmoid(), new ActivationSoftmax(),
|
||||
new ActivationSoftPlus(), new ActivationSoftSign(), new ActivationTanH(), new ActivationGELU(), new ActivationGELU(true)};
|
||||
new ActivationHardSigmoid(), new ActivationHardTanH(), new ActivationIdentity(),
|
||||
new ActivationLReLU(0.25), new ActivationRationalTanh(), new ActivationReLU(),
|
||||
new ActivationRReLU(0.25, 0.5), new ActivationSigmoid(), new ActivationSoftmax(),
|
||||
new ActivationSoftPlus(), new ActivationSoftSign(), new ActivationTanH(), new ActivationGELU(), new ActivationGELU(true)};
|
||||
|
||||
String[][] expectedFields = new String[][] {{"@class"}, //Cube
|
||||
{"@class", "alpha"}, //ELU
|
||||
{"@class"}, //Hard sigmoid
|
||||
{"@class"}, //Hard TanH
|
||||
{"@class"}, //Identity
|
||||
{"@class", "alpha"}, //Leaky Relu
|
||||
{"@class"}, //rational tanh
|
||||
{"@class", "max", "negativeSlope", "threshold"}, //relu
|
||||
{"@class", "l", "u"}, //rrelu
|
||||
{"@class"}, //sigmoid
|
||||
{"@class"}, //Softmax
|
||||
{"@class"}, //Softplus
|
||||
{"@class"}, //Softsign
|
||||
{"@class"}, //Tanh
|
||||
{"@class", "precise"}, //GELU
|
||||
{"@class", "precise"} //GELU precise
|
||||
{"@class", "alpha"}, //ELU
|
||||
{"@class"}, //Hard sigmoid
|
||||
{"@class"}, //Hard TanH
|
||||
{"@class"}, //Identity
|
||||
{"@class", "alpha"}, //Leaky Relu
|
||||
{"@class"}, //rational tanh
|
||||
{"@class", "max", "negativeSlope", "threshold"}, //relu
|
||||
{"@class", "l", "u"}, //rrelu
|
||||
{"@class"}, //sigmoid
|
||||
{"@class"}, //Softmax
|
||||
{"@class"}, //Softplus
|
||||
{"@class"}, //Softsign
|
||||
{"@class"}, //Tanh
|
||||
{"@class", "precise"}, //GELU
|
||||
{"@class", "precise"} //GELU precise
|
||||
|
||||
};
|
||||
|
||||
|
@ -172,7 +174,7 @@ public class TestActivation extends BaseNd4jTest {
|
|||
String[] expFields = expectedFields[i];
|
||||
|
||||
String msg = activations[i].toString() + "\tExpected fields: " + Arrays.toString(expFields)
|
||||
+ "\tActual fields: " + actualFieldsByName;
|
||||
+ "\tActual fields: " + actualFieldsByName;
|
||||
assertEquals(expFields.length, actualFieldsByName.size(),msg);
|
||||
|
||||
for (String s : expFields) {
|
||||
|
|
|
@ -20,21 +20,20 @@
|
|||
package org.nd4j.linalg.api;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.factory.Environment;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
|
||||
public class TestBackend extends BaseNd4jTest {
|
||||
public class TestBackend extends BaseNd4jTestWithBackends {
|
||||
|
||||
public TestBackend(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void TestBuildInfo(){
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBuildInfo(Nd4jBackend backend){
|
||||
System.out.println("Backend build info: " + backend.buildInfo());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,26 +20,27 @@
|
|||
package org.nd4j.linalg.api;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.factory.Environment;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
|
||||
public class TestEnvironment extends BaseNd4jTest {
|
||||
public class TestEnvironment extends BaseNd4jTestWithBackends {
|
||||
|
||||
public TestEnvironment(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEnvironment(){
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEnvironment(Nd4jBackend backend){
|
||||
Environment e = Nd4j.getEnvironment();
|
||||
System.out.println("BLAS version: " + e.blasMajorVersion() + "." + e.blasMinorVersion() + "." + e.blasPatchVersion());
|
||||
System.out.println("CPU: " + e.isCPU());
|
||||
|
|
|
@ -26,7 +26,9 @@ import org.bytedeco.javacpp.FloatPointer;
|
|||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -40,16 +42,12 @@ import java.util.Map;
|
|||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@Slf4j
|
||||
public class TestNDArrayCreation extends BaseNd4jTest {
|
||||
|
||||
|
||||
public TestNDArrayCreation(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Test
|
||||
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
|
||||
public void testBufferCreation() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBufferCreation(Nd4jBackend backend) {
|
||||
DataBuffer dataBuffer = Nd4j.createBuffer(new float[] {1, 2});
|
||||
Pointer pointer = dataBuffer.pointer();
|
||||
FloatPointer floatPointer = new FloatPointer(pointer);
|
||||
|
@ -69,6 +67,8 @@ public class TestNDArrayCreation extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
@Disabled
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCreateNpy() throws Exception {
|
||||
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile());
|
||||
assertEquals(2, arrCreate.size(0));
|
||||
|
@ -82,7 +82,9 @@ public class TestNDArrayCreation extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCreateNpz() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCreateNpz(Nd4jBackend backend) throws Exception {
|
||||
Map<String, INDArray> map = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile());
|
||||
assertEquals(true, map.containsKey("x"));
|
||||
assertEquals(true, map.containsKey("y"));
|
||||
|
@ -100,8 +102,7 @@ public class TestNDArrayCreation extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
|
||||
public void testCreateNpy3() throws Exception {
|
||||
public void testCreateNpy3(Nd4jBackend backend) throws Exception {
|
||||
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
|
||||
assertEquals(8, arrCreate.length());
|
||||
assertEquals(3, arrCreate.rank());
|
||||
|
@ -113,7 +114,7 @@ public class TestNDArrayCreation extends BaseNd4jTest {
|
|||
|
||||
@Test
|
||||
@Disabled // this is endless test
|
||||
public void testEndlessAllocation() {
|
||||
public void testEndlessAllocation(Nd4jBackend backend) {
|
||||
Nd4j.getEnvironment().setMaxSpecialMemory(1);
|
||||
while (true) {
|
||||
val arr = Nd4j.createUninitialized(DataType.FLOAT, 100000000);
|
||||
|
|
|
@ -21,24 +21,23 @@
|
|||
package org.nd4j.linalg.api;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.nd4j.common.util.ArrayUtil;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
|
||||
public class TestNDArrayCreationUtil extends BaseNd4jTest {
|
||||
public class TestNDArrayCreationUtil extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
public TestNDArrayCreationUtil(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testShapes() {
|
||||
|
||||
long[] shape2d = {2, 3};
|
||||
|
|
|
@ -21,20 +21,21 @@
|
|||
package org.nd4j.linalg.api;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
public class TestNamespaces extends BaseNd4jTest {
|
||||
public class TestNamespaces extends BaseNd4jTestWithBackends {
|
||||
|
||||
public TestNamespaces(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBitwiseSimple(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBitwiseSimple(Nd4jBackend backend){
|
||||
|
||||
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
|
||||
INDArray y = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
|
||||
|
@ -50,7 +51,9 @@ public class TestNamespaces extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testMathSimple(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testMathSimple(Nd4jBackend backend) {
|
||||
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1);
|
||||
INDArray abs = Nd4j.math.abs(x);
|
||||
// System.out.println(x);
|
||||
|
@ -65,7 +68,9 @@ public class TestNamespaces extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRandomSimple(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRandomSimple(Nd4jBackend backend){
|
||||
INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10);
|
||||
// System.out.println(normal);
|
||||
INDArray uniform = Nd4j.random.uniform(0, 1, DataType.FLOAT, 10);
|
||||
|
@ -73,7 +78,9 @@ public class TestNamespaces extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNeuralNetworkSimple(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNeuralNetworkSimple(Nd4jBackend backend){
|
||||
INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10));
|
||||
// System.out.println(out);
|
||||
INDArray out2 = Nd4j.nn.softmax(Nd4j.random.normal(0, 1, DataType.FLOAT, 4, 5), 1);
|
||||
|
|
|
@ -21,9 +21,10 @@
|
|||
package org.nd4j.linalg.api.blas;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -31,15 +32,14 @@ import org.nd4j.linalg.factory.Nd4jBackend;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class LapackTest extends BaseNd4jTest {
|
||||
public LapackTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
public class LapackTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
@Test
|
||||
public void testQRSquare() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testQRSquare(Nd4jBackend backend) {
|
||||
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
A = A.reshape('c', 3, 3);
|
||||
INDArray O = Nd4j.create(A.dataType(), A.shape());
|
||||
|
@ -57,7 +57,9 @@ public class LapackTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testQRRect() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testQRRect(Nd4jBackend backend) {
|
||||
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
A = A.reshape('f', 4, 3);
|
||||
INDArray O = Nd4j.create(A.dataType(), A.shape());
|
||||
|
@ -75,7 +77,9 @@ public class LapackTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testCholeskyL() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCholeskyL(Nd4jBackend backend) {
|
||||
INDArray A = Nd4j.create(new double[] {2, -1, 1, -1, 2, -1, 1, -1, 2,});
|
||||
A = A.reshape('c', 3, 3);
|
||||
INDArray O = Nd4j.create(A.dataType(), A.shape());
|
||||
|
@ -92,7 +96,9 @@ public class LapackTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testCholeskyU() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCholeskyU(Nd4jBackend backend) {
|
||||
INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,});
|
||||
A = A.reshape('f', 3, 3);
|
||||
INDArray O = Nd4j.create(A.dataType(), A.shape());
|
||||
|
|
|
@ -22,9 +22,10 @@ package org.nd4j.linalg.api.blas;
|
|||
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -35,14 +36,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
public class Level1Test extends BaseNd4jTest {
|
||||
public Level1Test(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
public class Level1Test extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Test
|
||||
public void testDot() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDot(Nd4jBackend backend) {
|
||||
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||
assertEquals(30, Nd4j.getBlasWrapper().dot(vec1, vec2), 1e-1);
|
||||
|
@ -55,7 +55,9 @@ public class Level1Test extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAxpy() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAxpy(Nd4jBackend backend) {
|
||||
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
INDArray row = matrix.getRow(1);
|
||||
Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row);
|
||||
|
@ -64,7 +66,9 @@ public class Level1Test extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAxpy2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAxpy2(Nd4jBackend backend) {
|
||||
val rowX = Nd4j.create(new double[]{1, 2, 3, 4});
|
||||
val rowY = Nd4j.create(new double[]{1, 2, 3, 4});
|
||||
val exp = Nd4j.create(new double[]{3, 6, 9, 12});
|
||||
|
|
|
@ -21,23 +21,23 @@
|
|||
package org.nd4j.linalg.api.blas;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class Level2Test extends BaseNd4jTest {
|
||||
public Level2Test(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
public class Level2Test extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Test
|
||||
public void testGemv1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemv1(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
||||
|
||||
|
@ -51,7 +51,9 @@ public class Level2Test extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGemv2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemv2(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
|
||||
|
||||
|
@ -65,7 +67,9 @@ public class Level2Test extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGemv3() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemv3(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
|
||||
|
||||
|
@ -79,7 +83,9 @@ public class Level2Test extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGemv4() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemv4(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
||||
|
||||
|
@ -93,7 +99,9 @@ public class Level2Test extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGemv5() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemv5(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
||||
|
||||
|
@ -109,7 +117,9 @@ public class Level2Test extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGemv6() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemv6(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
||||
|
||||
|
@ -125,7 +135,9 @@ public class Level2Test extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGemv7() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemv7(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
||||
|
||||
|
|
|
@ -21,23 +21,23 @@
|
|||
package org.nd4j.linalg.api.blas;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class Level3Test extends BaseNd4jTest {
|
||||
public Level3Test(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
public class Level3Test extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Test
|
||||
public void testGemm1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemm1(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape(1, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
||||
|
||||
|
@ -47,7 +47,9 @@ public class Level3Test extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGemm2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemm2(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape('f', 1, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
|
||||
|
||||
|
@ -57,7 +59,9 @@ public class Level3Test extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGemm3() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemm3(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
|
||||
|
||||
|
@ -75,7 +79,9 @@ public class Level3Test extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGemm4() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemm4(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);
|
||||
|
||||
|
@ -92,7 +98,9 @@ public class Level3Test extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGemm5() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemm5(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
|
||||
|
||||
|
@ -106,7 +114,9 @@ public class Level3Test extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGemm6() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemm6(Nd4jBackend backend) {
|
||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
||||
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);
|
||||
|
||||
|
|
|
@ -21,9 +21,10 @@
|
|||
package org.nd4j.linalg.api.blas.params;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
@ -33,16 +34,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
public class ParamsTestsF extends BaseNd4jTest {
|
||||
|
||||
|
||||
public ParamsTestsF(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class ParamsTestsF extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Test
|
||||
public void testGemm() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGemm (Nd4jBackend backend) {
|
||||
INDArray a = Nd4j.create(2, 2);
|
||||
INDArray b = Nd4j.create(2, 3);
|
||||
INDArray c = Nd4j.create(2, 3);
|
||||
|
|
|
@ -25,9 +25,10 @@ import org.bytedeco.javacpp.*;
|
|||
import org.bytedeco.javacpp.indexer.*;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||
|
@ -45,16 +46,15 @@ import java.nio.ByteOrder;
|
|||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
public class DataBufferTests extends BaseNd4jTest {
|
||||
|
||||
public DataBufferTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class DataBufferTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
@Test
|
||||
@Disabled("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657")
|
||||
public void testNoArgCreateBufferFromArray() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNoArgCreateBufferFromArray(Nd4jBackend backend) {
|
||||
|
||||
//Tests here:
|
||||
//1. Create from JVM array
|
||||
|
@ -280,7 +280,9 @@ public class DataBufferTests extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testCreateTypedBuffer() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testCreateTypedBuffer(Nd4jBackend backend) {
|
||||
|
||||
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
||||
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
||||
|
@ -350,7 +352,9 @@ public class DataBufferTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAsBytes() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAsBytes(Nd4jBackend backend) {
|
||||
INDArray orig = Nd4j.linspace(DataType.INT, 0, 10, 1);
|
||||
|
||||
for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16,
|
||||
|
@ -404,7 +408,9 @@ public class DataBufferTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testEnsureLocation(){
|
||||
//https://github.com/eclipse/deeplearning4j/issues/8783
|
||||
Nd4j.create(1);
|
||||
|
|
|
@ -23,9 +23,10 @@ package org.nd4j.linalg.api.buffer;
|
|||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -33,13 +34,10 @@ import org.nd4j.linalg.factory.Nd4jBackend;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class DataTypeValidationTests extends BaseNd4jTest {
|
||||
DataType initialType;
|
||||
|
||||
public DataTypeValidationTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
|
||||
DataType initialType = Nd4j.dataType();
|
||||
|
||||
|
||||
@BeforeEach
|
||||
public void setUp() {
|
||||
|
@ -48,7 +46,7 @@ public class DataTypeValidationTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@AfterEach
|
||||
public void shutUp() {
|
||||
public void reset() {
|
||||
Nd4j.setDataType(initialType);
|
||||
}
|
||||
|
||||
|
@ -73,7 +71,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
|
|||
* Testing level1 blas
|
||||
*/
|
||||
@Test()
|
||||
public void testBlasValidation1() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBlasValidation1(Nd4jBackend backend) {
|
||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||
INDArray x = Nd4j.create(10);
|
||||
|
||||
|
@ -90,7 +90,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
|
|||
* Testing level2 blas
|
||||
*/
|
||||
@Test()
|
||||
public void testBlasValidation2() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBlasValidation2(Nd4jBackend backend) {
|
||||
assertThrows(RuntimeException.class,() -> {
|
||||
INDArray a = Nd4j.create(100, 10);
|
||||
INDArray x = Nd4j.create(100);
|
||||
|
@ -108,7 +110,9 @@ public class DataTypeValidationTests extends BaseNd4jTest {
|
|||
* Testing level3 blas
|
||||
*/
|
||||
@Test()
|
||||
public void testBlasValidation3() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBlasValidation3(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
INDArray x = Nd4j.create(100, 100);
|
||||
|
||||
|
|
|
@ -26,9 +26,10 @@ import org.bytedeco.javacpp.indexer.Indexer;
|
|||
import org.junit.jupiter.api.*;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||
|
@ -54,34 +55,31 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
|
||||
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
|
||||
public class DoubleDataBufferTest extends BaseNd4jTest {
|
||||
public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
|
||||
DataType initialType;
|
||||
|
||||
public DoubleDataBufferTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
initialType = Nd4j.dataType();
|
||||
}
|
||||
DataType initialType = Nd4j.dataType();
|
||||
|
||||
|
||||
|
||||
@BeforeEach
|
||||
public void before() {
|
||||
public void before(Nd4jBackend backend) {
|
||||
|
||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void after() {
|
||||
public void after(Nd4jBackend backend) {
|
||||
DataTypeUtil.setDTypeForContext(initialType);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPointerCreation() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPointerCreation(Nd4jBackend backend) {
|
||||
DoublePointer floatPointer = new DoublePointer(1, 2, 3, 4);
|
||||
Indexer indexer = DoubleIndexer.create(floatPointer);
|
||||
DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.DOUBLE, 4, indexer);
|
||||
|
@ -89,8 +87,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
|
|||
assertArrayEquals(other.asDouble(), buffer.asDouble(), 0.001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetSet() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetSet(Nd4jBackend backend) {
|
||||
double[] d1 = new double[] {1, 2, 3, 4};
|
||||
DataBuffer d = Nd4j.createBuffer(d1);
|
||||
double[] d2 = d.asDouble();
|
||||
|
@ -100,10 +100,12 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerialization2() throws Exception {
|
||||
INDArray[] arr = new INDArray[] {Nd4j.ones(1, 10),
|
||||
// Nd4j.ones(5,10).getRow(2)
|
||||
// Nd4j.ones(5,10).getRow(2)
|
||||
};
|
||||
|
||||
for (INDArray a : arr) {
|
||||
|
@ -128,7 +130,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerialization(@TempDir Path testDir) throws Exception {
|
||||
File dir = testDir.toFile();
|
||||
DataBuffer buf = Nd4j.createBuffer(5);
|
||||
|
@ -150,8 +154,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testDup() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDup(Nd4jBackend backend) {
|
||||
double[] d1 = new double[] {1, 2, 3, 4};
|
||||
DataBuffer d = Nd4j.createBuffer(d1);
|
||||
DataBuffer d2 = d.dup();
|
||||
|
@ -160,8 +166,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
|
||||
@Test
|
||||
public void testPut() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPut(Nd4jBackend backend) {
|
||||
double[] d1 = new double[] {1, 2, 3, 4};
|
||||
DataBuffer d = Nd4j.createBuffer(d1);
|
||||
d.put(0, 0.0);
|
||||
|
@ -171,8 +179,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testGetRange() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetRange(Nd4jBackend backend) {
|
||||
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
|
||||
double[] get = buffer.getDoublesAt(0, 3);
|
||||
double[] data = new double[] {1, 2, 3};
|
||||
|
@ -186,8 +196,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetOffsetRange() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetOffsetRange(Nd4jBackend backend) {
|
||||
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
|
||||
double[] get = buffer.getDoublesAt(1, 3);
|
||||
double[] data = new double[] {2, 3, 4};
|
||||
|
@ -201,8 +213,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAssign() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAssign(Nd4jBackend backend) {
|
||||
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
|
||||
DataBuffer one = Nd4j.createBuffer(new double[] {1});
|
||||
DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3});
|
||||
|
@ -212,8 +226,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testOffset() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testOffset(Nd4jBackend backend) {
|
||||
DataBuffer create = Nd4j.createBuffer(new double[] {1, 2, 3, 4}, 2);
|
||||
assertEquals(2, create.length());
|
||||
assertEquals(0, create.offset());
|
||||
|
@ -222,8 +238,10 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReallocation() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReallocation(Nd4jBackend backend) {
|
||||
DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4});
|
||||
assertEquals(4, buffer.capacity());
|
||||
double[] old = buffer.asDouble();
|
||||
|
@ -232,10 +250,12 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
|
|||
assertArrayEquals(old, Arrays.copyOf(buffer.asDouble(), 4), 1e-1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReallocationWorkspace() {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReallocationWorkspace(Nd4jBackend backend) {
|
||||
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
||||
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
||||
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
||||
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
|
||||
|
||||
DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4});
|
||||
|
@ -249,7 +269,9 @@ public class DoubleDataBufferTest extends BaseNd4jTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAddressPointer(){
|
||||
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
|
||||
return;
|
||||
|
|
|
@ -27,7 +27,9 @@ import org.bytedeco.javacpp.indexer.Indexer;
|
|||
import org.junit.jupiter.api.*;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||
|
@ -54,14 +56,9 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
* @author Adam Gibson
|
||||
*/
|
||||
@Disabled("AB 2019/05/21 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
|
||||
public class FloatDataBufferTest extends BaseNd4jTest {
|
||||
public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
DataType initialType;
|
||||
|
||||
public FloatDataBufferTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
initialType = Nd4j.dataType();
|
||||
}
|
||||
DataType initialType = Nd4j.dataType();
|
||||
|
||||
@BeforeEach
|
||||
public void before() {
|
||||
|
@ -76,7 +73,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testPointerCreation() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPointerCreation(Nd4jBackend backend) {
|
||||
FloatPointer floatPointer = new FloatPointer(1, 2, 3, 4);
|
||||
Indexer indexer = FloatIndexer.create(floatPointer);
|
||||
DataBuffer buffer = Nd4j.createBuffer(floatPointer, DataType.FLOAT, 4, indexer);
|
||||
|
@ -85,7 +84,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGetSet() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetSet(Nd4jBackend backend) {
|
||||
float[] d1 = new float[] {1, 2, 3, 4};
|
||||
DataBuffer d = Nd4j.createBuffer(d1);
|
||||
float[] d2 = d.asFloat();
|
||||
|
@ -96,7 +97,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testSerialization(@TempDir Path tempDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerialization(@TempDir Path tempDir,Nd4jBackend backend) throws Exception {
|
||||
File dir = tempDir.toFile();
|
||||
DataBuffer buf = Nd4j.createBuffer(5);
|
||||
String fileName = "buf.ser";
|
||||
|
@ -117,7 +120,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDup() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testDup(Nd4jBackend backend) {
|
||||
float[] d1 = new float[] {1, 2, 3, 4};
|
||||
DataBuffer d = Nd4j.createBuffer(d1);
|
||||
DataBuffer d2 = d.dup();
|
||||
|
@ -125,7 +130,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testToNio() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testToNio(Nd4jBackend backend) {
|
||||
DataBuffer buff = Nd4j.createTypedBuffer(new double[] {1, 2, 3, 4}, DataType.FLOAT);
|
||||
assertEquals(4, buff.length());
|
||||
if (buff.allocationMode() == DataBuffer.AllocationMode.HEAP)
|
||||
|
@ -137,7 +144,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testPut() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPut(Nd4jBackend backend) {
|
||||
float[] d1 = new float[] {1, 2, 3, 4};
|
||||
DataBuffer d = Nd4j.createBuffer(d1);
|
||||
d.put(0, 0.0);
|
||||
|
@ -148,7 +157,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testGetRange() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetRange(Nd4jBackend backend) {
|
||||
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
||||
float[] get = buffer.getFloatsAt(0, 3);
|
||||
float[] data = new float[] {1, 2, 3};
|
||||
|
@ -164,7 +175,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testGetOffsetRange() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetOffsetRange(Nd4jBackend backend) {
|
||||
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
||||
float[] get = buffer.getFloatsAt(1, 3);
|
||||
float[] data = new float[] {2, 3, 4};
|
||||
|
@ -181,7 +194,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testAsBytes() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAsBytes(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.create(5);
|
||||
byte[] d = arr.data().asBytes();
|
||||
assertEquals(4 * 5, d.length,getFailureMessage());
|
||||
|
@ -191,7 +206,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAssign() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAssign(Nd4jBackend backend) {
|
||||
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
|
||||
DataBuffer one = Nd4j.createBuffer(new double[] {1});
|
||||
DataBuffer twoThree = Nd4j.createBuffer(new double[] {2, 3});
|
||||
|
@ -201,7 +218,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReadWrite() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReadWrite(Nd4jBackend backend) throws Exception {
|
||||
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
|
||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||
DataOutputStream dos = new DataOutputStream(bos);
|
||||
|
@ -215,7 +234,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testOffset() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testOffset(Nd4jBackend backend) {
|
||||
DataBuffer create = Nd4j.createBuffer(new float[] {1, 2, 3, 4}, 2);
|
||||
assertEquals(2, create.length());
|
||||
assertEquals(0, create.offset());
|
||||
|
@ -225,7 +246,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReallocation() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReallocation(Nd4jBackend backend) {
|
||||
DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4});
|
||||
assertEquals(4, buffer.capacity());
|
||||
float[] old = buffer.asFloat();
|
||||
|
@ -236,7 +259,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReallocationWorkspace() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReallocationWorkspace(Nd4jBackend backend) {
|
||||
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
||||
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
||||
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
|
||||
|
@ -253,7 +278,9 @@ public class FloatDataBufferTest extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testAddressPointer(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAddressPointer(Nd4jBackend backend){
|
||||
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -23,7 +23,9 @@ package org.nd4j.linalg.api.buffer;
|
|||
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
|
||||
|
@ -37,13 +39,12 @@ import java.util.Arrays;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class IntDataBufferTests extends BaseNd4jTest {
|
||||
public class IntDataBufferTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
public IntDataBufferTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testBasicSerde1() throws Exception {
|
||||
|
||||
|
||||
|
@ -82,7 +83,9 @@ public class IntDataBufferTests extends BaseNd4jTest {
|
|||
*/
|
||||
|
||||
@Test
|
||||
public void testReallocation() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReallocation(Nd4jBackend backend) {
|
||||
DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
|
||||
assertEquals(4, buffer.capacity());
|
||||
buffer.reallocate(6);
|
||||
|
@ -94,9 +97,11 @@ public class IntDataBufferTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testReallocationWorkspace() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testReallocationWorkspace(Nd4jBackend backend) {
|
||||
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
||||
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
||||
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
||||
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
|
||||
|
||||
DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
|
||||
|
|
|
@ -21,9 +21,10 @@
|
|||
package org.nd4j.linalg.api.indexing;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -37,17 +38,15 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
|||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
public class IndexingTests extends BaseNd4jTest {
|
||||
public class IndexingTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
public IndexingTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testINDArrayIndexingEqualToRank() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testINDArrayIndexingEqualToRank(Nd4jBackend backend) {
|
||||
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
|
||||
INDArray indexes = Nd4j.create(new double[][]{
|
||||
{0,1,2},
|
||||
|
@ -62,7 +61,9 @@ public class IndexingTests extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testINDArrayIndexingLessThanRankSimple() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testINDArrayIndexingLessThanRankSimple(Nd4jBackend backend) {
|
||||
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
|
||||
INDArray indexes = Nd4j.create(new double[][]{
|
||||
{0},
|
||||
|
@ -76,7 +77,9 @@ public class IndexingTests extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testINDArrayIndexingLessThanRankFourDimension() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testINDArrayIndexingLessThanRankFourDimension(Nd4jBackend backend) {
|
||||
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE);
|
||||
INDArray indexes = Nd4j.create(new double[][]{
|
||||
{0},{1}
|
||||
|
@ -89,7 +92,9 @@ public class IndexingTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testPutSimple() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPutSimple(Nd4jBackend backend) {
|
||||
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2);
|
||||
INDArray indexes = Nd4j.create(new double[][]{
|
||||
{0},{1}
|
||||
|
@ -101,7 +106,9 @@ public class IndexingTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGetScalar() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetScalar(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
|
||||
INDArray d = arr.get(NDArrayIndex.point(1));
|
||||
assertTrue(d.isScalar());
|
||||
|
@ -110,14 +117,18 @@ public class IndexingTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNewAxis() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNewAxis(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
|
||||
INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
|
||||
// System.out.println(view);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorIndexing() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testVectorIndexing(Nd4jBackend backend) {
|
||||
INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE);
|
||||
int[] index = new int[] {5, 8, 9};
|
||||
INDArray columnsTest = x.getColumns(index);
|
||||
|
@ -129,7 +140,9 @@ public class IndexingTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGetRowsColumnsMatrix() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetRowsColumnsMatrix(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 6);
|
||||
INDArray firstAndSecondColumnsAssertion = Nd4j.create(new double[][] {{1, 5}, {2, 6}, {3, 7}, {4, 8}});
|
||||
|
||||
|
@ -147,7 +160,9 @@ public class IndexingTests extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testSlicing() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSlicing(Nd4jBackend backend) {
|
||||
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
|
||||
INDArray slice1Assert = Nd4j.create(new double[] {2, 6, 10, 14});
|
||||
INDArray slice1Test = arange.slice(1);
|
||||
|
@ -155,7 +170,9 @@ public class IndexingTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testArangeMul() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testArangeMul(Nd4jBackend backend) {
|
||||
INDArray arange = Nd4j.arange(1, 17).reshape('f', 4, 4).castTo(DataType.DOUBLE);
|
||||
INDArrayIndex index = NDArrayIndex.interval(0, 2);
|
||||
INDArray get = arange.get(index, index);
|
||||
|
@ -167,7 +184,9 @@ public class IndexingTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGetIndicesVector() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetIndicesVector(Nd4jBackend backend) {
|
||||
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
|
||||
INDArray test = Nd4j.create(new double[] {2, 3});
|
||||
INDArray result = line.get(NDArrayIndex.point(0), NDArrayIndex.interval(1, 3));
|
||||
|
@ -175,7 +194,9 @@ public class IndexingTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testGetIndicesVectorView() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetIndicesVectorView(Nd4jBackend backend) {
|
||||
INDArray matrix = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape('c',5, 5);
|
||||
INDArray column = matrix.getColumn(0).reshape(1,5);
|
||||
INDArray test = Nd4j.create(new double[] {6, 11});
|
||||
|
@ -193,7 +214,9 @@ public class IndexingTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void test2dGetPoint(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void test2dGetPoint(Nd4jBackend backend){
|
||||
INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4);
|
||||
for( int i=0; i<3; i++ ){
|
||||
INDArray exp = Nd4j.create(new double[]{i*4+1, i*4+2, i*4+3, i*4+4});
|
||||
|
@ -206,7 +229,7 @@ public class IndexingTests extends BaseNd4jTest {
|
|||
assertEquals(exp, get);
|
||||
}
|
||||
|
||||
for( int i=0; i<4; i++ ){
|
||||
for( int i = 0; i < 4; i++) {
|
||||
INDArray exp = Nd4j.create(new double[]{1+i, 5+i, 9+i});
|
||||
INDArray col = arr.getColumn(i);
|
||||
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(i));
|
||||
|
|
|
@ -21,10 +21,11 @@
|
|||
package org.nd4j.linalg.api.indexing;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -49,16 +50,15 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
|
|||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
public class IndexingTestsC extends BaseNd4jTest {
|
||||
|
||||
public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
public IndexingTestsC(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNegativeBounds() {
|
||||
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
|
||||
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
|
||||
|
@ -70,7 +70,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(assertion,get);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNewAxis() {
|
||||
INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2);
|
||||
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all());
|
||||
|
@ -79,7 +81,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void broadcastBug() {
|
||||
INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2});
|
||||
final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0));
|
||||
|
@ -90,7 +94,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIntervalsIn3D() {
|
||||
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
|
||||
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
|
||||
|
@ -99,7 +105,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSmallInterval() {
|
||||
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
|
||||
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
|
||||
|
@ -108,7 +116,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAllWithNewAxisAndInterval() {
|
||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3);
|
||||
|
@ -117,7 +127,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(assertion2, get2);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAllWithNewAxisInMiddle() {
|
||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3);
|
||||
|
@ -126,7 +138,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(assertion2, get2);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testAllWithNewAxis() {
|
||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||
INDArray get = arr.get(newAxis(), all(), point(1));
|
||||
|
@ -136,7 +150,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIndexingWithMmul() {
|
||||
INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
|
||||
INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
|
||||
|
@ -147,7 +163,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(assertion, c);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPointPointInterval() {
|
||||
INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3);
|
||||
INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3));
|
||||
|
@ -156,7 +174,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(assertion, get);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIntervalLowerBound() {
|
||||
INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
|
||||
|
@ -167,7 +187,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetPointRowVector() {
|
||||
INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1);
|
||||
|
||||
|
@ -177,7 +199,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(Nd4j.linspace(1, 100, 100, DataType.DOUBLE), arr2);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSpecifiedIndexVector() {
|
||||
INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
|
||||
INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2);
|
||||
|
@ -194,7 +218,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testPutRowIndexing() {
|
||||
INDArray arr = Nd4j.ones(1, 10);
|
||||
INDArray row = Nd4j.create(1, 10);
|
||||
|
@ -204,7 +230,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(arr, row);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testVectorIndexing2() {
|
||||
INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true));
|
||||
INDArray assertion = Nd4j.create(new double[] {2, 4});
|
||||
|
@ -219,7 +247,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testOffsetsC() {
|
||||
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
assertEquals(3, NDArrayIndex.offset(arr, 1, 1));
|
||||
|
@ -235,7 +265,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIndexFor() {
|
||||
long[] shape = {1, 2};
|
||||
INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape);
|
||||
|
@ -244,7 +276,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetScalar() {
|
||||
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
|
||||
INDArray d = arr.get(point(1));
|
||||
|
@ -253,7 +287,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testVectorIndexing() {
|
||||
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1);
|
||||
INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
|
||||
|
@ -261,14 +297,18 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(assertion, viewTest);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNegativeIndices() {
|
||||
INDArray test = Nd4j.create(10, 10, 10);
|
||||
test.putScalar(new int[] {0, 0, -1}, 1.0);
|
||||
assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetIndices2d() {
|
||||
INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2);
|
||||
INDArray firstRow = twoByTwo.getRow(0);
|
||||
|
@ -286,7 +326,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(Nd4j.create(new double[] {4}, new int[]{1,1}), individualElement);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetRow() {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5);
|
||||
|
@ -303,7 +345,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetRowEdgeCase() {
|
||||
INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
|
||||
INDArray get = rowVec.getRow(0); //Returning shape [1,1]
|
||||
|
@ -312,7 +356,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(rowVec, get);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetColumnEdgeCase() {
|
||||
INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose();
|
||||
INDArray get = colVec.getColumn(0); //Returning shape [1,1]
|
||||
|
@ -321,7 +367,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(colVec, get);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testConcatColumns() {
|
||||
INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE);
|
||||
INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE);
|
||||
|
@ -330,7 +378,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(assertion, concat);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testGetIndicesVector() {
|
||||
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
|
||||
INDArray test = Nd4j.create(new double[] {2, 3});
|
||||
|
@ -338,7 +388,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(test, result);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testArangeMul() {
|
||||
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
|
||||
INDArrayIndex index = interval(0, 2);
|
||||
|
@ -349,7 +401,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
assertEquals(assertion, mul);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIndexingThorough(){
|
||||
long[] fullShape = {3,4,5,6,7};
|
||||
|
||||
|
@ -549,7 +603,9 @@ public class IndexingTestsC extends BaseNd4jTest {
|
|||
return d;
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void debugging(){
|
||||
long[] inShape = {3,4};
|
||||
INDArrayIndex[] indexes = new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(1, 2, 4)};
|
||||
|
|
|
@ -21,9 +21,10 @@
|
|||
package org.nd4j.linalg.api.indexing.resolve;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
@ -36,15 +37,14 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
public class NDArrayIndexResolveTests extends BaseNd4jTest {
|
||||
|
||||
public NDArrayIndexResolveTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class NDArrayIndexResolveTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
@Test
|
||||
public void testResolvePoint() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testResolvePoint(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2);
|
||||
INDArrayIndex[] test = NDArrayIndex.resolve(arr.shape(), NDArrayIndex.point(1));
|
||||
INDArrayIndex[] assertion = {NDArrayIndex.point(1), NDArrayIndex.all()};
|
||||
|
@ -59,6 +59,8 @@ public class NDArrayIndexResolveTests extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testResolvePointVector() {
|
||||
INDArray arr = Nd4j.linspace(1, 4, 4);
|
||||
INDArrayIndex[] getPoint = {NDArrayIndex.point(1)};
|
||||
|
|
|
@ -21,9 +21,10 @@
|
|||
package org.nd4j.linalg.api.indexing.shape;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||
import org.nd4j.linalg.indexing.Indices;
|
||||
|
@ -34,19 +35,15 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
|||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
public class IndexShapeTests extends BaseNd4jTest {
|
||||
|
||||
public IndexShapeTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
|
||||
public class IndexShapeTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
private int[] shape = {1, 1, 2, 1, 3, 4, 5, 1};
|
||||
|
||||
@Test
|
||||
public void testSinglePoint() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSinglePoint(Nd4jBackend backend) {
|
||||
/*
|
||||
Assumes all indexes are filled out.
|
||||
Test simple general point case
|
||||
|
@ -77,7 +74,9 @@ public class IndexShapeTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testInterval() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testInterval(Nd4jBackend backend) {
|
||||
int[] basicAssertion = {1, 1, 1, 1, 3, 1, 2, 1};
|
||||
INDArrayIndex[] basicTest = {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 1),
|
||||
NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(1, 2),
|
||||
|
@ -88,7 +87,9 @@ public class IndexShapeTests extends BaseNd4jTest {
|
|||
|
||||
|
||||
@Test
|
||||
public void testNewAxis() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNewAxis(Nd4jBackend backend) {
|
||||
//normal prepend
|
||||
int[] prependAssertion = {1, 1, 1, 1, 2, 1, 3, 4, 5, 1};
|
||||
INDArrayIndex[] prependTest = {NDArrayIndex.newAxis(), NDArrayIndex.newAxis(), NDArrayIndex.all(),
|
||||
|
|
|
@ -21,9 +21,10 @@
|
|||
package org.nd4j.linalg.api.indexing.shape;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.indexing.Indices;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
|
@ -33,25 +34,26 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
|||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
public class IndexShapeTests2d extends BaseNd4jTest {
|
||||
|
||||
public IndexShapeTests2d(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class IndexShapeTests2d extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
private long[] shape = {3, 2};
|
||||
|
||||
|
||||
@Test
|
||||
public void test2dCases() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void test2dCases(Nd4jBackend backend) {
|
||||
assertArrayEquals(new long[] {1, 2}, Indices.shape(shape, NDArrayIndex.point(1)));
|
||||
assertArrayEquals(new long[] {3, 1},
|
||||
Indices.shape(shape, NDArrayIndex.all(), NDArrayIndex.point(1)));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNewAxis2d() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNewAxis2d(Nd4jBackend backend) {
|
||||
assertArrayEquals(new long[] {1, 3, 2}, Indices.shape(shape,
|
||||
NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.all()));
|
||||
assertArrayEquals(new long[] {3, 1, 2}, Indices.shape(shape,
|
||||
|
|
|
@ -22,9 +22,10 @@ package org.nd4j.linalg.api.iterator;
|
|||
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
|
@ -33,15 +34,14 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
|||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
public class NDIndexIteratorTest extends BaseNd4jTest {
|
||||
|
||||
public NDIndexIteratorTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class NDIndexIteratorTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
@Test
|
||||
public void testIterate() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testIterate(Nd4jBackend backend) {
|
||||
val shapeIter = new NdIndexIterator(2, 2);
|
||||
val possibleSolutions = new long[][] {{0, 0}, {0, 1}, {1, 0}, {1, 1},};
|
||||
|
||||
|
|
|
@ -28,9 +28,10 @@ import org.apache.commons.lang3.ArrayUtils;
|
|||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
@ -45,18 +46,15 @@ import java.util.List;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestNdArrReadWriteTxt extends BaseNd4jTest {
|
||||
|
||||
|
||||
public TestNdArrReadWriteTxt(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Test
|
||||
public void compareAfterWrite(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
int [] ranksToCheck = new int[] {0,1,2,3,4};
|
||||
for (int i=0; i<ranksToCheck.length;i++) {
|
||||
for (int i = 0; i < ranksToCheck.length; i++) {
|
||||
// log.info("Checking read write arrays with rank " + ranksToCheck[i]);
|
||||
compareArrays(ranksToCheck[i],ordering(), testDir);
|
||||
}
|
||||
|
@ -84,7 +82,9 @@ public class TestNdArrReadWriteTxt extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testNd4jReadWriteText(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNd4jReadWriteText(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
|
||||
File dir = testDir.toFile();
|
||||
int count = 0;
|
||||
|
|
|
@ -25,9 +25,10 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import java.nio.file.Path;
|
||||
|
@ -35,17 +36,14 @@ import java.nio.file.Path;
|
|||
import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays;
|
||||
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestNdArrReadWriteTxtC extends BaseNd4jTest {
|
||||
|
||||
public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends {
|
||||
|
||||
public TestNdArrReadWriteTxtC(Nd4jBackend backend) {
|
||||
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void compareAfterWrite(@TempDir Path testDir) throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
int[] ranksToCheck = new int[]{0, 1, 2, 3, 4};
|
||||
for (int i = 0; i < ranksToCheck.length; i++) {
|
||||
log.info("Checking read write arrays with rank " + ranksToCheck[i]);
|
||||
|
|
|
@ -21,9 +21,10 @@
|
|||
package org.nd4j.linalg.api.ndarray;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
|
@ -32,16 +33,14 @@ import java.io.*;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestSerialization extends BaseNd4jTest {
|
||||
|
||||
public TestSerialization(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class TestSerialization extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
@Test
|
||||
public void testSerializationFullArrayNd4jWriteRead() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
|
||||
int length = 100;
|
||||
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
|
||||
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
|
||||
|
@ -71,7 +70,9 @@ public class TestSerialization extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSerializationFullArrayJava() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerializationFullArrayJava(Nd4jBackend backend) throws Exception {
|
||||
int length = 100;
|
||||
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
|
||||
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
|
||||
|
@ -102,7 +103,9 @@ public class TestSerialization extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSerializationOnViewsNd4jWriteRead() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerializationOnViewsNd4jWriteRead(Nd4jBackend backend) throws Exception {
|
||||
int length = 100;
|
||||
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
|
||||
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
|
||||
|
@ -138,7 +141,9 @@ public class TestSerialization extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSerializationOnViewsJava() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerializationOnViewsJava(Nd4jBackend backend) throws Exception {
|
||||
int length = 100;
|
||||
INDArray arrC = Nd4j.linspace(1, length, length).reshape('c', 10, 10);
|
||||
INDArray arrF = Nd4j.linspace(1, length, length).reshape('f', 10, 10);
|
||||
|
|
|
@ -24,9 +24,10 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import lombok.val;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -39,15 +40,11 @@ import java.io.*;
|
|||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestSerializationDoubleToFloat extends BaseNd4jTest {
|
||||
|
||||
DataType initialType;
|
||||
public class TestSerializationDoubleToFloat extends BaseNd4jTestWithBackends {
|
||||
|
||||
DataType initialType = Nd4j.dataType();
|
||||
|
||||
public TestSerializationDoubleToFloat(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
this.initialType = Nd4j.dataType();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void after() {
|
||||
|
@ -55,7 +52,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSerializationFullArrayNd4jWriteRead() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
|
||||
int length = 4;
|
||||
|
||||
//WRITE OUT A DOUBLE ARRAY
|
||||
|
@ -93,7 +92,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSerializationFullArrayJava() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerializationFullArrayJava(Nd4jBackend backend) throws Exception {
|
||||
int length = 100;
|
||||
Nd4j.create(1);
|
||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
||||
|
@ -123,7 +124,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSerializationOnViewsNd4jWriteRead() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerializationOnViewsNd4jWriteRead(Nd4jBackend backend) throws Exception {
|
||||
int length = 100;
|
||||
Nd4j.create(1);
|
||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
||||
|
@ -153,7 +156,9 @@ public class TestSerializationDoubleToFloat extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testSerializationOnViewsJava() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerializationOnViewsJava(Nd4jBackend backend) throws Exception {
|
||||
int length = 100;
|
||||
Nd4j.create(1);
|
||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
||||
|
|
|
@ -22,9 +22,10 @@ package org.nd4j.linalg.api.ndarray;
|
|||
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -37,23 +38,21 @@ import java.io.*;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestSerializationFloatToDouble extends BaseNd4jTest {
|
||||
|
||||
DataType initialType;
|
||||
public class TestSerializationFloatToDouble extends BaseNd4jTestWithBackends {
|
||||
|
||||
DataType initialType = Nd4j.dataType();
|
||||
|
||||
public TestSerializationFloatToDouble(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
this.initialType = Nd4j.dataType();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void after() {
|
||||
Nd4j.setDataType(this.initialType);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSerializationFullArrayNd4jWriteRead() throws Exception {
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerializationFullArrayNd4jWriteRead(Nd4jBackend backend) throws Exception {
|
||||
int length = 100;
|
||||
|
||||
//WRITE OUT A FLOAT ARRAY
|
||||
|
@ -85,7 +84,9 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
|
|||
assertTrue(Transforms.abs(arr1.sub(arr2).div(arr1)).maxNumber().doubleValue() < 0.01);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerializationFullArrayJava() throws Exception {
|
||||
int length = 100;
|
||||
Nd4j.create(1);
|
||||
|
@ -116,7 +117,9 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
|
|||
assertTrue(Transforms.abs(arr1.sub(arr2).div(arr1)).maxNumber().doubleValue() < 0.01);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerializationOnViewsNd4jWriteRead() throws Exception {
|
||||
int length = 100;
|
||||
Nd4j.create(1);
|
||||
|
@ -146,7 +149,9 @@ public class TestSerializationFloatToDouble extends BaseNd4jTest {
|
|||
assertTrue(Transforms.abs(sub1.sub(arr2).div(sub1)).maxNumber().doubleValue() < 0.01);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testSerializationOnViewsJava() throws Exception {
|
||||
int length = 100;
|
||||
Nd4j.create(1);
|
||||
|
|
|
@ -21,9 +21,10 @@
|
|||
package org.nd4j.linalg.api.rng;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
@ -33,14 +34,13 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
/**
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@RunWith(Parameterized.class)
|
||||
public class RngTests extends BaseNd4jTest {
|
||||
public RngTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
public class RngTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
@Test
|
||||
public void testRngConstitency() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRngConstitency(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(123);
|
||||
INDArray arr = Nd4j.rand(1, 5);
|
||||
Nd4j.getRandom().setSeed(123);
|
||||
|
@ -49,7 +49,9 @@ public class RngTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRandomWithOrder() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRandomWithOrder(Nd4jBackend backend) {
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -105,7 +107,9 @@ public class RngTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRandomBinomial() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRandomBinomial(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
//silly tests. Just increasing the usage for randomBinomial to stop compiler warnings.
|
||||
INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3);
|
||||
|
|
|
@ -23,9 +23,10 @@ package org.nd4j.linalg.api.string;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
@ -35,22 +36,23 @@ import org.nd4j.linalg.string.NDArrayStrings;
|
|||
* @author Adam Gibson
|
||||
*/
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestFormatting extends BaseNd4jTest {
|
||||
|
||||
public TestFormatting(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
public class TestFormatting extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
@Test
|
||||
public void testTwoByTwo() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testTwoByTwo(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.create(2, 2, 2, 2);
|
||||
System.out.println(new NDArrayStrings().format(arr));
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNd4jArrayString() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testNd4jArrayString(Nd4jBackend backend) {
|
||||
|
||||
INDArray arr = Nd4j.create(new float[]{1f, 20000000f, 40.838383f, 3f}, new int[]{2, 2});
|
||||
|
||||
|
@ -71,7 +73,9 @@ public class TestFormatting extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testRange() {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
||||
public void testRange(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.create(new double[][]{
|
||||
{-1,0,1,0},
|
||||
{-0.1, 0.1, -10, 10},
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue