More junit 4 removal, all tests compile. FIxed parameterized test invocation. Deleted nd4j-parameter-server-status that used play
parent
3c6014271e
commit
e0077c38a9
|
@ -34,8 +34,9 @@ import java.net.URISyntaxException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
|
|
@ -32,9 +32,7 @@ import org.junit.jupiter.api.Test;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
|
||||||
|
|
||||||
public class PartitionerTests extends BaseND4JTest {
|
public class PartitionerTests extends BaseND4JTest {
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -58,7 +58,7 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.joda.time.DateTimeFieldType;
|
import org.joda.time.DateTimeFieldType;
|
||||||
import org.joda.time.DateTimeZone;
|
import org.joda.time.DateTimeZone;
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -71,7 +71,7 @@ import java.io.ObjectOutputStream;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestTransforms extends BaseND4JTest {
|
public class TestTransforms extends BaseND4JTest {
|
||||||
|
@ -277,22 +277,22 @@ public class TestTransforms extends BaseND4JTest {
|
||||||
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
|
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
|
||||||
outputColumns.add(NEW_COLUMN);
|
outputColumns.add(NEW_COLUMN);
|
||||||
Schema newSchema = transform.transform(schema);
|
Schema newSchema = transform.transform(schema);
|
||||||
Assert.assertEquals(outputColumns, newSchema.getColumnNames());
|
assertEquals(outputColumns, newSchema.getColumnNames());
|
||||||
|
|
||||||
List<Writable> input = new ArrayList<>();
|
List<Writable> input = new ArrayList<>();
|
||||||
input.addAll(COLUMN_VALUES);
|
input.addAll(COLUMN_VALUES);
|
||||||
|
|
||||||
transform.setInputSchema(schema);
|
transform.setInputSchema(schema);
|
||||||
List<Writable> transformed = transform.map(input);
|
List<Writable> transformed = transform.map(input);
|
||||||
Assert.assertEquals(NEW_COLUMN_VALUE, transformed.get(transformed.size() - 1).toString());
|
assertEquals(NEW_COLUMN_VALUE, transformed.get(transformed.size() - 1).toString());
|
||||||
|
|
||||||
List<Text> outputColumnValues = new ArrayList<>(COLUMN_VALUES);
|
List<Text> outputColumnValues = new ArrayList<>(COLUMN_VALUES);
|
||||||
outputColumnValues.add(new Text(NEW_COLUMN_VALUE));
|
outputColumnValues.add(new Text(NEW_COLUMN_VALUE));
|
||||||
Assert.assertEquals(outputColumnValues, transformed);
|
assertEquals(outputColumnValues, transformed);
|
||||||
|
|
||||||
String s = JsonMappers.getMapper().writeValueAsString(transform);
|
String s = JsonMappers.getMapper().writeValueAsString(transform);
|
||||||
Transform transform2 = JsonMappers.getMapper().readValue(s, ConcatenateStringColumns.class);
|
Transform transform2 = JsonMappers.getMapper().readValue(s, ConcatenateStringColumns.class);
|
||||||
Assert.assertEquals(transform, transform2);
|
assertEquals(transform, transform2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -309,7 +309,7 @@ public class TestTransforms extends BaseND4JTest {
|
||||||
transform.setInputSchema(schema);
|
transform.setInputSchema(schema);
|
||||||
Schema newSchema = transform.transform(schema);
|
Schema newSchema = transform.transform(schema);
|
||||||
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
|
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
|
||||||
Assert.assertEquals(outputColumns, newSchema.getColumnNames());
|
assertEquals(outputColumns, newSchema.getColumnNames());
|
||||||
|
|
||||||
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.LOWER);
|
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.LOWER);
|
||||||
transform.setInputSchema(schema);
|
transform.setInputSchema(schema);
|
||||||
|
@ -320,8 +320,8 @@ public class TestTransforms extends BaseND4JTest {
|
||||||
output.add(new Text(TEXT_LOWER_CASE));
|
output.add(new Text(TEXT_LOWER_CASE));
|
||||||
output.add(new Text(TEXT_MIXED_CASE));
|
output.add(new Text(TEXT_MIXED_CASE));
|
||||||
List<Writable> transformed = transform.map(input);
|
List<Writable> transformed = transform.map(input);
|
||||||
Assert.assertEquals(transformed.get(0).toString(), TEXT_LOWER_CASE);
|
assertEquals(transformed.get(0).toString(), TEXT_LOWER_CASE);
|
||||||
Assert.assertEquals(transformed, output);
|
assertEquals(transformed, output);
|
||||||
|
|
||||||
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.UPPER);
|
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.UPPER);
|
||||||
transform.setInputSchema(schema);
|
transform.setInputSchema(schema);
|
||||||
|
@ -329,12 +329,12 @@ public class TestTransforms extends BaseND4JTest {
|
||||||
output.add(new Text(TEXT_UPPER_CASE));
|
output.add(new Text(TEXT_UPPER_CASE));
|
||||||
output.add(new Text(TEXT_MIXED_CASE));
|
output.add(new Text(TEXT_MIXED_CASE));
|
||||||
transformed = transform.map(input);
|
transformed = transform.map(input);
|
||||||
Assert.assertEquals(transformed.get(0).toString(), TEXT_UPPER_CASE);
|
assertEquals(transformed.get(0).toString(), TEXT_UPPER_CASE);
|
||||||
Assert.assertEquals(transformed, output);
|
assertEquals(transformed, output);
|
||||||
|
|
||||||
String s = JsonMappers.getMapper().writeValueAsString(transform);
|
String s = JsonMappers.getMapper().writeValueAsString(transform);
|
||||||
Transform transform2 = JsonMappers.getMapper().readValue(s, ChangeCaseStringTransform.class);
|
Transform transform2 = JsonMappers.getMapper().readValue(s, ChangeCaseStringTransform.class);
|
||||||
Assert.assertEquals(transform, transform2);
|
assertEquals(transform, transform2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -1530,7 +1530,7 @@ public class TestTransforms extends BaseND4JTest {
|
||||||
|
|
||||||
String json = JsonMappers.getMapper().writeValueAsString(t);
|
String json = JsonMappers.getMapper().writeValueAsString(t);
|
||||||
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToCountsNDArrayTransform.class);
|
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToCountsNDArrayTransform.class);
|
||||||
Assert.assertEquals(t, transform2);
|
assertEquals(t, transform2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1551,7 +1551,7 @@ public class TestTransforms extends BaseND4JTest {
|
||||||
|
|
||||||
String json = JsonMappers.getMapper().writeValueAsString(t);
|
String json = JsonMappers.getMapper().writeValueAsString(t);
|
||||||
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToIndicesNDArrayTransform.class);
|
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToIndicesNDArrayTransform.class);
|
||||||
Assert.assertEquals(t, transform2);
|
assertEquals(t, transform2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -54,10 +54,8 @@ import java.io.FileOutputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import static java.nio.channels.Channels.newChannel;
|
import static java.nio.channels.Channels.newChannel;
|
||||||
import static junit.framework.TestCase.assertTrue;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
|
@ -42,7 +42,7 @@ import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
|
||||||
import static org.datavec.api.transform.schema.Schema.Builder;
|
import static org.datavec.api.transform.schema.Schema.Builder;
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
|
|
@ -40,8 +40,9 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class NormalizationTests extends BaseSparkTest {
|
public class NormalizationTests extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
|
@ -165,9 +165,8 @@
|
||||||
</plugin>
|
</plugin>
|
||||||
<plugin>
|
<plugin>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<version>${maven-surefire-plugin.version}</version>
|
|
||||||
<configuration>
|
<configuration>
|
||||||
<argLine> "</argLine>
|
<argLine></argLine>
|
||||||
|
|
||||||
<!--
|
<!--
|
||||||
By default: Surefire will set the classpath based on the manifest. Because tests are not included
|
By default: Surefire will set the classpath based on the manifest. Because tests are not included
|
||||||
|
|
|
@ -32,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@Disabled
|
@Disabled
|
||||||
|
|
|
@ -27,12 +27,11 @@ import org.nd4j.evaluation.curves.RocCurve;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import static junit.framework.TestCase.assertNull;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@DisplayName("Eval Json Test")
|
@DisplayName("Eval Json Test")
|
||||||
class EvalJsonTest extends BaseDL4JTest {
|
class EvalJsonTest extends BaseDL4JTest {
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,6 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
||||||
return 90000L;
|
return 90000L;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
public void testYoloOutputLayer(CNN2DFormat format) {
|
public void testYoloOutputLayer(CNN2DFormat format) {
|
||||||
|
@ -180,7 +179,6 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
public void yoloGradientCheckRealData(@TempDir Path testDir,CNN2DFormat format) throws Exception {
|
public void yoloGradientCheckRealData(@TempDir Path testDir,CNN2DFormat format) throws Exception {
|
||||||
|
|
|
@ -150,7 +150,6 @@ class BidirectionalTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@DisplayName("Compare Implementations Comp Graph")
|
@DisplayName("Compare Implementations Comp Graph")
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
void compareImplementationsCompGraph(RNNFormat rnnFormat) {
|
void compareImplementationsCompGraph(RNNFormat rnnFormat) {
|
||||||
|
|
|
@ -55,7 +55,6 @@ class MaskZeroLayerTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@DisplayName("Activate")
|
@DisplayName("Activate")
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
void activate(RNNFormat rnnDataFormat) {
|
void activate(RNNFormat rnnDataFormat) {
|
||||||
|
@ -96,7 +95,6 @@ class MaskZeroLayerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
|
|
||||||
@DisplayName("Test Serialization")
|
@DisplayName("Test Serialization")
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
void testSerialization(RNNFormat rnnDataFormat) {
|
void testSerialization(RNNFormat rnnDataFormat) {
|
||||||
|
|
|
@ -109,7 +109,6 @@ public class RnnDataFormatTests extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
public void testLSTM(boolean helpers,
|
public void testLSTM(boolean helpers,
|
||||||
|
|
|
@ -62,7 +62,6 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
|
||||||
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
public void testLastTimeStepVertex(RNNFormat rnnDataFormat) {
|
public void testLastTimeStepVertex(RNNFormat rnnDataFormat) {
|
||||||
|
@ -127,7 +126,6 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
|
||||||
TestUtils.testModelSerialization(graph);
|
TestUtils.testModelSerialization(graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
public void testMaskingAndAllMasked(RNNFormat rnnDataFormat) {
|
public void testMaskingAndAllMasked(RNNFormat rnnDataFormat) {
|
||||||
|
|
|
@ -65,7 +65,6 @@ public class TestRnnLayers extends BaseDL4JTest {
|
||||||
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
public void testTimeStepIs3Dimensional(RNNFormat rnnDataFormat) {
|
public void testTimeStepIs3Dimensional(RNNFormat rnnDataFormat) {
|
||||||
|
@ -117,7 +116,6 @@ public class TestRnnLayers extends BaseDL4JTest {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
public void testDropoutRecurrentLayers(RNNFormat rnnDataFormat){
|
public void testDropoutRecurrentLayers(RNNFormat rnnDataFormat){
|
||||||
|
@ -217,7 +215,6 @@ public class TestRnnLayers extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
public void testMismatchedInputLabelLength(RNNFormat rnnDataFormat){
|
public void testMismatchedInputLabelLength(RNNFormat rnnDataFormat){
|
||||||
|
|
|
@ -55,7 +55,6 @@ public class TestSimpleRnn extends BaseDL4JTest {
|
||||||
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
public void testSimpleRnn(RNNFormat rnnDataFormat) {
|
public void testSimpleRnn(RNNFormat rnnDataFormat) {
|
||||||
|
@ -126,7 +125,6 @@ public class TestSimpleRnn extends BaseDL4JTest {
|
||||||
TestUtils.testModelSerialization(net);
|
TestUtils.testModelSerialization(net);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
public void testBiasInit(RNNFormat rnnDataFormat) {
|
public void testBiasInit(RNNFormat rnnDataFormat) {
|
||||||
|
|
|
@ -61,7 +61,6 @@ public class TestTimeDistributed extends BaseDL4JTest {
|
||||||
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
return Arrays.asList(RNNFormat.values()).stream().map(Arguments::of);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("#params")
|
@MethodSource("#params")
|
||||||
public void testTimeDistributed(RNNFormat rnnDataFormat){
|
public void testTimeDistributed(RNNFormat rnnDataFormat){
|
||||||
|
|
|
@ -26,8 +26,8 @@ import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import oshi.json.SystemInfo;
|
import oshi.json.SystemInfo;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertNotNull;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||||
|
|
||||||
@Disabled("AB 2019/05/24 - Failing on CI - \"Could not initialize class oshi.jna.platform.linux.Libc\" - Issue #7657")
|
@Disabled("AB 2019/05/24 - Failing on CI - \"Could not initialize class oshi.jna.platform.linux.Libc\" - Issue #7657")
|
||||||
public class TestHardWareMetric extends BaseDL4JTest {
|
public class TestHardWareMetric extends BaseDL4JTest {
|
||||||
|
|
|
@ -45,9 +45,9 @@ import org.nd4j.linalg.learning.config.Sgd;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
import org.nd4j.common.resources.Resources;
|
import org.nd4j.common.resources.Resources;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
import static org.junit.Assume.assumeNotNull;
|
import static org.junit.jupiter.api.Assumptions.*;
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
@ -67,11 +67,11 @@ class ModelGuesserTest extends BaseDL4JTest {
|
||||||
File f = Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5");
|
File f = Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5");
|
||||||
assertTrue(f.exists());
|
assertTrue(f.exists());
|
||||||
Model guess1 = ModelGuesser.loadModelGuess(f.getAbsolutePath());
|
Model guess1 = ModelGuesser.loadModelGuess(f.getAbsolutePath());
|
||||||
assumeNotNull(guess1);
|
assertNotNull(guess1);
|
||||||
f = Resources.asFile("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5");
|
f = Resources.asFile("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5");
|
||||||
assertTrue(f.exists());
|
assertTrue(f.exists());
|
||||||
Model guess2 = ModelGuesser.loadModelGuess(f.getAbsolutePath());
|
Model guess2 = ModelGuesser.loadModelGuess(f.getAbsolutePath());
|
||||||
assumeNotNull(guess2);
|
assertNotNull(guess2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -81,13 +81,13 @@ class ModelGuesserTest extends BaseDL4JTest {
|
||||||
assertTrue(f.exists());
|
assertTrue(f.exists());
|
||||||
try (InputStream inputStream = new FileInputStream(f)) {
|
try (InputStream inputStream = new FileInputStream(f)) {
|
||||||
Model guess1 = ModelGuesser.loadModelGuess(inputStream);
|
Model guess1 = ModelGuesser.loadModelGuess(inputStream);
|
||||||
assumeNotNull(guess1);
|
assertNotNull(guess1);
|
||||||
}
|
}
|
||||||
f = Resources.asFile("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5");
|
f = Resources.asFile("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5");
|
||||||
assertTrue(f.exists());
|
assertTrue(f.exists());
|
||||||
try (InputStream inputStream = new FileInputStream(f)) {
|
try (InputStream inputStream = new FileInputStream(f)) {
|
||||||
Model guess1 = ModelGuesser.loadModelGuess(inputStream);
|
Model guess1 = ModelGuesser.loadModelGuess(inputStream);
|
||||||
assumeNotNull(guess1);
|
assertNotNull(guess1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,7 +157,7 @@ class ModelGuesserTest extends BaseDL4JTest {
|
||||||
ModelSerializer.writeModel(net, tempFile, true);
|
ModelSerializer.writeModel(net, tempFile, true);
|
||||||
try (InputStream inputStream = new FileInputStream(tempFile)) {
|
try (InputStream inputStream = new FileInputStream(tempFile)) {
|
||||||
MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(inputStream);
|
MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(inputStream);
|
||||||
assumeNotNull(network);
|
assertNotNull(network);
|
||||||
assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson());
|
assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson());
|
||||||
assertEquals(net.params(), network.params());
|
assertEquals(net.params(), network.params());
|
||||||
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
|
assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
|
||||||
|
|
|
@ -52,7 +52,7 @@ import java.util.zip.ZipFile;
|
||||||
import java.util.zip.ZipInputStream;
|
import java.util.zip.ZipInputStream;
|
||||||
import java.util.zip.ZipOutputStream;
|
import java.util.zip.ZipOutputStream;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class ModelValidatorTests extends BaseDL4JTest {
|
public class ModelValidatorTests extends BaseDL4JTest {
|
||||||
|
@ -91,7 +91,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
|
||||||
ValidationResult vr2 = DL4JModelValidator.validateMultiLayerNetwork(f2);
|
ValidationResult vr2 = DL4JModelValidator.validateMultiLayerNetwork(f2);
|
||||||
assertFalse(vr2.isValid());
|
assertFalse(vr2.isValid());
|
||||||
String s = vr2.getIssues().get(0);
|
String s = vr2.getIssues().get(0);
|
||||||
assertTrue(s, s.contains("zip") && s.contains("corrupt"));
|
assertTrue(s.contains("zip") && s.contains("corrupt"), s);
|
||||||
assertEquals("MultiLayerNetwork", vr2.getFormatType());
|
assertEquals("MultiLayerNetwork", vr2.getFormatType());
|
||||||
assertEquals(MultiLayerNetwork.class, vr2.getFormatClass());
|
assertEquals(MultiLayerNetwork.class, vr2.getFormatClass());
|
||||||
assertNotNull(vr2.getException());
|
assertNotNull(vr2.getException());
|
||||||
|
@ -108,7 +108,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
|
||||||
assertFalse(vr3.isValid());
|
assertFalse(vr3.isValid());
|
||||||
s = vr3.getIssues().get(0);
|
s = vr3.getIssues().get(0);
|
||||||
assertEquals(1, vr3.getIssues().size());
|
assertEquals(1, vr3.getIssues().size());
|
||||||
assertTrue(s, s.contains("missing") && s.contains("configuration"));
|
assertTrue(s.contains("missing") && s.contains("configuration"), s);
|
||||||
assertEquals("MultiLayerNetwork", vr3.getFormatType());
|
assertEquals("MultiLayerNetwork", vr3.getFormatType());
|
||||||
assertEquals(MultiLayerNetwork.class, vr3.getFormatClass());
|
assertEquals(MultiLayerNetwork.class, vr3.getFormatClass());
|
||||||
assertNull(vr3.getException());
|
assertNull(vr3.getException());
|
||||||
|
@ -126,7 +126,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
|
||||||
assertFalse(vr4.isValid());
|
assertFalse(vr4.isValid());
|
||||||
s = vr4.getIssues().get(0);
|
s = vr4.getIssues().get(0);
|
||||||
assertEquals(1, vr4.getIssues().size());
|
assertEquals(1, vr4.getIssues().size());
|
||||||
assertTrue(s, s.contains("missing") && s.contains("coefficients"));
|
assertTrue(s.contains("missing") && s.contains("coefficients"), s);
|
||||||
assertEquals("MultiLayerNetwork", vr4.getFormatType());
|
assertEquals("MultiLayerNetwork", vr4.getFormatType());
|
||||||
assertEquals(MultiLayerNetwork.class, vr4.getFormatClass());
|
assertEquals(MultiLayerNetwork.class, vr4.getFormatClass());
|
||||||
assertNull(vr4.getException());
|
assertNull(vr4.getException());
|
||||||
|
@ -169,7 +169,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
|
||||||
assertFalse(vr6.isValid());
|
assertFalse(vr6.isValid());
|
||||||
s = vr6.getIssues().get(0);
|
s = vr6.getIssues().get(0);
|
||||||
assertEquals(1, vr6.getIssues().size());
|
assertEquals(1, vr6.getIssues().size());
|
||||||
assertTrue(s, s.contains("JSON") && s.contains("valid") && s.contains("MultiLayerConfiguration"));
|
assertTrue(s.contains("JSON") && s.contains("valid") && s.contains("MultiLayerConfiguration"), s);
|
||||||
assertEquals("MultiLayerNetwork", vr6.getFormatType());
|
assertEquals("MultiLayerNetwork", vr6.getFormatType());
|
||||||
assertEquals(MultiLayerNetwork.class, vr6.getFormatClass());
|
assertEquals(MultiLayerNetwork.class, vr6.getFormatClass());
|
||||||
assertNotNull(vr6.getException());
|
assertNotNull(vr6.getException());
|
||||||
|
@ -209,7 +209,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
|
||||||
ValidationResult vr2 = DL4JModelValidator.validateComputationGraph(f2);
|
ValidationResult vr2 = DL4JModelValidator.validateComputationGraph(f2);
|
||||||
assertFalse(vr2.isValid());
|
assertFalse(vr2.isValid());
|
||||||
String s = vr2.getIssues().get(0);
|
String s = vr2.getIssues().get(0);
|
||||||
assertTrue(s, s.contains("zip") && s.contains("corrupt"));
|
assertTrue(s.contains("zip") && s.contains("corrupt"), s);
|
||||||
assertEquals("ComputationGraph", vr2.getFormatType());
|
assertEquals("ComputationGraph", vr2.getFormatType());
|
||||||
assertEquals(ComputationGraph.class, vr2.getFormatClass());
|
assertEquals(ComputationGraph.class, vr2.getFormatClass());
|
||||||
assertNotNull(vr2.getException());
|
assertNotNull(vr2.getException());
|
||||||
|
@ -226,7 +226,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
|
||||||
assertFalse(vr3.isValid());
|
assertFalse(vr3.isValid());
|
||||||
s = vr3.getIssues().get(0);
|
s = vr3.getIssues().get(0);
|
||||||
assertEquals(1, vr3.getIssues().size());
|
assertEquals(1, vr3.getIssues().size());
|
||||||
assertTrue(s, s.contains("missing") && s.contains("configuration"));
|
assertTrue(s.contains("missing") && s.contains("configuration"), s);
|
||||||
assertEquals("ComputationGraph", vr3.getFormatType());
|
assertEquals("ComputationGraph", vr3.getFormatType());
|
||||||
assertEquals(ComputationGraph.class, vr3.getFormatClass());
|
assertEquals(ComputationGraph.class, vr3.getFormatClass());
|
||||||
assertNull(vr3.getException());
|
assertNull(vr3.getException());
|
||||||
|
@ -244,7 +244,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
|
||||||
assertFalse(vr4.isValid());
|
assertFalse(vr4.isValid());
|
||||||
s = vr4.getIssues().get(0);
|
s = vr4.getIssues().get(0);
|
||||||
assertEquals(1, vr4.getIssues().size());
|
assertEquals(1, vr4.getIssues().size());
|
||||||
assertTrue(s, s.contains("missing") && s.contains("coefficients"));
|
assertTrue(s.contains("missing") && s.contains("coefficients"), s);
|
||||||
assertEquals("ComputationGraph", vr4.getFormatType());
|
assertEquals("ComputationGraph", vr4.getFormatType());
|
||||||
assertEquals(ComputationGraph.class, vr4.getFormatClass());
|
assertEquals(ComputationGraph.class, vr4.getFormatClass());
|
||||||
assertNull(vr4.getException());
|
assertNull(vr4.getException());
|
||||||
|
@ -287,7 +287,7 @@ public class ModelValidatorTests extends BaseDL4JTest {
|
||||||
assertFalse(vr6.isValid());
|
assertFalse(vr6.isValid());
|
||||||
s = vr6.getIssues().get(0);
|
s = vr6.getIssues().get(0);
|
||||||
assertEquals(1, vr6.getIssues().size());
|
assertEquals(1, vr6.getIssues().size());
|
||||||
assertTrue(s, s.contains("JSON") && s.contains("valid") && s.contains("ComputationGraphConfiguration"));
|
assertTrue(s.contains("JSON") && s.contains("valid") && s.contains("ComputationGraphConfiguration"), s);
|
||||||
assertEquals("ComputationGraph", vr6.getFormatType());
|
assertEquals("ComputationGraph", vr6.getFormatType());
|
||||||
assertEquals(ComputationGraph.class, vr6.getFormatClass());
|
assertEquals(ComputationGraph.class, vr6.getFormatClass());
|
||||||
assertNotNull(vr6.getException());
|
assertNotNull(vr6.getException());
|
||||||
|
|
|
@ -24,9 +24,7 @@ import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.core.util.UIDProvider;
|
import org.deeplearning4j.core.util.UIDProvider;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
|
||||||
|
|
||||||
public class TestUIDProvider extends BaseDL4JTest {
|
public class TestUIDProvider extends BaseDL4JTest {
|
||||||
|
|
||||||
|
|
|
@ -89,6 +89,10 @@
|
||||||
<groupId>org.junit.jupiter</groupId>
|
<groupId>org.junit.jupiter</groupId>
|
||||||
<artifactId>junit-jupiter-engine</artifactId>
|
<artifactId>junit-jupiter-engine</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-params</artifactId>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.deeplearning4j</groupId>
|
<groupId>org.deeplearning4j</groupId>
|
||||||
<artifactId>deeplearning4j-common-tests</artifactId>
|
<artifactId>deeplearning4j-common-tests</artifactId>
|
||||||
|
@ -120,113 +124,4 @@
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-native</id>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-native</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<inherited>true</inherited>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-native</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
<configuration>
|
|
||||||
<environmentVariables>
|
|
||||||
|
|
||||||
</environmentVariables>
|
|
||||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>*.java</include>
|
|
||||||
<include>**/*.java</include>
|
|
||||||
<include>**/Test*.java</include>
|
|
||||||
<include>**/*Test.java</include>
|
|
||||||
<include>**/*TestCase.java</include>
|
|
||||||
</includes>
|
|
||||||
<junitArtifactName>org.junit.jupiter:junit-jupiter</junitArtifactName>
|
|
||||||
<systemPropertyVariables>
|
|
||||||
<org.nd4j.linalg.defaultbackend>
|
|
||||||
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
|
||||||
</org.nd4j.linalg.defaultbackend>
|
|
||||||
<org.nd4j.linalg.tests.backendstorun>
|
|
||||||
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
|
||||||
</org.nd4j.linalg.tests.backendstorun>
|
|
||||||
</systemPropertyVariables>
|
|
||||||
<!--
|
|
||||||
Maximum heap size was set to 8g, as a minimum required value for tests run.
|
|
||||||
Depending on a build machine, default value is not always enough.
|
|
||||||
|
|
||||||
For testing large zoo models, this may not be enough (so comment it out).
|
|
||||||
-->
|
|
||||||
<argLine></argLine>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-cuda-11.0</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.maven.surefire</groupId>
|
|
||||||
<artifactId>surefire-junit47</artifactId>
|
|
||||||
<version>2.19.1</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
<configuration>
|
|
||||||
<environmentVariables>
|
|
||||||
</environmentVariables>
|
|
||||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>*.java</include>
|
|
||||||
<include>**/*.java</include>
|
|
||||||
<include>**/Test*.java</include>
|
|
||||||
<include>**/*Test.java</include>
|
|
||||||
<include>**/*TestCase.java</include>
|
|
||||||
</includes>
|
|
||||||
<junitArtifactName>org.junit.jupiter:junit-jupiter</junitArtifactName>
|
|
||||||
<systemPropertyVariables>
|
|
||||||
<org.nd4j.linalg.defaultbackend>
|
|
||||||
org.nd4j.linalg.jcublas.JCublasBackend
|
|
||||||
</org.nd4j.linalg.defaultbackend>
|
|
||||||
<org.nd4j.linalg.tests.backendstorun>
|
|
||||||
org.nd4j.linalg.jcublas.JCublasBackend
|
|
||||||
</org.nd4j.linalg.tests.backendstorun>
|
|
||||||
</systemPropertyVariables>
|
|
||||||
<!--
|
|
||||||
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
|
||||||
Depending on a build machine, default value is not always enough.
|
|
||||||
-->
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
</profile>
|
|
||||||
</profiles>
|
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -107,17 +107,18 @@
|
||||||
</dependencyManagement>
|
</dependencyManagement>
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
|
<!-- For unit tests -->
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.junit.jupiter</groupId>
|
<groupId>org.junit.jupiter</groupId>
|
||||||
<artifactId>junit-jupiter-api</artifactId>
|
<artifactId>junit-jupiter-api</artifactId>
|
||||||
<version>${junit.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.junit.jupiter</groupId>
|
<groupId>org.junit.jupiter</groupId>
|
||||||
<artifactId>junit-jupiter-engine</artifactId>
|
<artifactId>junit-jupiter-engine</artifactId>
|
||||||
<version>${junit.version}</version>
|
</dependency>
|
||||||
<scope>test</scope>
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-params</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.projectlombok</groupId>
|
<groupId>org.projectlombok</groupId>
|
||||||
|
@ -230,7 +231,6 @@
|
||||||
<plugins>
|
<plugins>
|
||||||
<plugin>
|
<plugin>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<version>${maven-surefire-plugin.version}</version>
|
|
||||||
<inherited>true</inherited>
|
<inherited>true</inherited>
|
||||||
<configuration>
|
<configuration>
|
||||||
<!--
|
<!--
|
||||||
|
@ -250,6 +250,13 @@
|
||||||
<include>**/*.java</include>
|
<include>**/*.java</include>
|
||||||
</includes>
|
</includes>
|
||||||
</configuration>
|
</configuration>
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.maven.surefire</groupId>
|
||||||
|
<artifactId>surefire-junit-platform</artifactId>
|
||||||
|
<version>${maven-surefire-plugin.version}</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
</plugin>
|
</plugin>
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.eclipse.m2e</groupId>
|
<groupId>org.eclipse.m2e</groupId>
|
||||||
|
@ -316,6 +323,21 @@
|
||||||
<artifactId>nd4j-native</artifactId>
|
<artifactId>nd4j-native</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-engine</artifactId>
|
||||||
|
<version>${junit.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-params</artifactId>
|
||||||
|
<version>${junit.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.maven.surefire</groupId>
|
||||||
|
<artifactId>surefire-junit-platform</artifactId>
|
||||||
|
<version>${maven-surefire-plugin.version}</version>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
<configuration>
|
<configuration>
|
||||||
<environmentVariables>
|
<environmentVariables>
|
||||||
|
@ -344,7 +366,7 @@
|
||||||
|
|
||||||
For testing large zoo models, this may not be enough (so comment it out).
|
For testing large zoo models, this may not be enough (so comment it out).
|
||||||
-->
|
-->
|
||||||
<argLine> "</argLine>
|
<argLine></argLine>
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
|
@ -376,13 +398,7 @@
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<dependencies>
|
<version>${maven-surefire-plugin.version}</version>
|
||||||
<dependency>
|
|
||||||
<groupId>org.junit</groupId>
|
|
||||||
<artifactId>surefire-junit5</artifactId>
|
|
||||||
<version>5.0.0-ALPHA</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
<configuration>
|
<configuration>
|
||||||
<environmentVariables>
|
<environmentVariables>
|
||||||
</environmentVariables>
|
</environmentVariables>
|
||||||
|
@ -409,6 +425,7 @@
|
||||||
-->
|
-->
|
||||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
||||||
</configuration>
|
</configuration>
|
||||||
|
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
</build>
|
</build>
|
||||||
|
|
|
@ -1259,7 +1259,7 @@ char* DoubleToBuffer(double value, char* buffer) {
|
||||||
// DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all
|
// DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all
|
||||||
// platforms these days. Just in case some system exists where DBL_DIG
|
// platforms these days. Just in case some system exists where DBL_DIG
|
||||||
// is significantly larger -- and risks overflowing our buffer -- we have
|
// is significantly larger -- and risks overflowing our buffer -- we have
|
||||||
// this assert.
|
// this
|
||||||
GOOGLE_COMPILE_ASSERT(DBL_DIG < 20, DBL_DIG_is_too_big);
|
GOOGLE_COMPILE_ASSERT(DBL_DIG < 20, DBL_DIG_is_too_big);
|
||||||
|
|
||||||
if (value == std::numeric_limits<double>::infinity()) {
|
if (value == std::numeric_limits<double>::infinity()) {
|
||||||
|
@ -1377,7 +1377,7 @@ char* FloatToBuffer(float value, char* buffer) {
|
||||||
// FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
|
// FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
|
||||||
// platforms these days. Just in case some system exists where FLT_DIG
|
// platforms these days. Just in case some system exists where FLT_DIG
|
||||||
// is significantly larger -- and risks overflowing our buffer -- we have
|
// is significantly larger -- and risks overflowing our buffer -- we have
|
||||||
// this assert.
|
// this
|
||||||
GOOGLE_COMPILE_ASSERT(FLT_DIG < 10, FLT_DIG_is_too_big);
|
GOOGLE_COMPILE_ASSERT(FLT_DIG < 10, FLT_DIG_is_too_big);
|
||||||
|
|
||||||
if (value == std::numeric_limits<double>::infinity()) {
|
if (value == std::numeric_limits<double>::infinity()) {
|
||||||
|
|
|
@ -182,7 +182,7 @@ size_t DoubleToBuffer(double value, char* buffer) {
|
||||||
// DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all
|
// DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all
|
||||||
// platforms these days. Just in case some system exists where DBL_DIG
|
// platforms these days. Just in case some system exists where DBL_DIG
|
||||||
// is significantly larger -- and risks overflowing our buffer -- we have
|
// is significantly larger -- and risks overflowing our buffer -- we have
|
||||||
// this assert.
|
// this
|
||||||
static_assert(DBL_DIG < 20, "DBL_DIG is too big");
|
static_assert(DBL_DIG < 20, "DBL_DIG is too big");
|
||||||
|
|
||||||
if (std::abs(value) <= kDoublePrecisionCheckMax) {
|
if (std::abs(value) <= kDoublePrecisionCheckMax) {
|
||||||
|
@ -363,7 +363,7 @@ size_t FloatToBuffer(float value, char* buffer) {
|
||||||
// FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
|
// FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
|
||||||
// platforms these days. Just in case some system exists where FLT_DIG
|
// platforms these days. Just in case some system exists where FLT_DIG
|
||||||
// is significantly larger -- and risks overflowing our buffer -- we have
|
// is significantly larger -- and risks overflowing our buffer -- we have
|
||||||
// this assert.
|
// this
|
||||||
static_assert(FLT_DIG < 10, "FLT_DIG is too big");
|
static_assert(FLT_DIG < 10, "FLT_DIG is too big");
|
||||||
|
|
||||||
int snprintf_result =
|
int snprintf_result =
|
||||||
|
|
|
@ -139,13 +139,6 @@
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.maven.surefire</groupId>
|
|
||||||
<artifactId>surefire-junit47</artifactId>
|
|
||||||
<version>2.19.1</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
<configuration>
|
<configuration>
|
||||||
<environmentVariables>
|
<environmentVariables>
|
||||||
<LD_LIBRARY_PATH>
|
<LD_LIBRARY_PATH>
|
||||||
|
|
|
@ -81,7 +81,6 @@
|
||||||
|
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
|
|
||||||
<!-- Skip execution of Javadoc since some versions run out of memory -->
|
<!-- Skip execution of Javadoc since some versions run out of memory -->
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
@ -124,7 +123,6 @@
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<version>2.19.1</version>
|
|
||||||
<configuration>
|
<configuration>
|
||||||
<environmentVariables>
|
<environmentVariables>
|
||||||
<LD_LIBRARY_PATH>${env.LD_LIBRARY_PATH}:${user.dir}</LD_LIBRARY_PATH>
|
<LD_LIBRARY_PATH>${env.LD_LIBRARY_PATH}:${user.dir}</LD_LIBRARY_PATH>
|
||||||
|
|
|
@ -20,21 +20,19 @@
|
||||||
|
|
||||||
package org.nd4j;
|
package org.nd4j;
|
||||||
|
|
||||||
import org.junit.AfterClass;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.runner.RunWith;
|
|
||||||
import org.junit.runners.Suite;
|
|
||||||
import org.nd4j.autodiff.opvalidation.*;
|
import org.nd4j.autodiff.opvalidation.*;
|
||||||
import org.nd4j.autodiff.validation.OpValidation;
|
import org.nd4j.autodiff.validation.OpValidation;
|
||||||
//import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
|
//import org.nd4j.imports.tfgraphs.TFGraphTestAllSameDiff;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import static org.junit.Assume.assumeFalse;
|
import static org.junit.jupiter.api.Assumptions.assumeFalse;
|
||||||
|
|
||||||
@RunWith(Suite.class)
|
|
||||||
@Suite.SuiteClasses({
|
/*@Suite.SuiteClasses({
|
||||||
//Note: these will be run as part of the suite only, and will NOT be run again separately
|
//Note: these will be run as part of the suite only, and will NOT be run again separately
|
||||||
LayerOpValidation.class,
|
LayerOpValidation.class,
|
||||||
LossOpValidation.class,
|
LossOpValidation.class,
|
||||||
|
@ -48,7 +46,7 @@ import static org.junit.Assume.assumeFalse;
|
||||||
//TF import tests
|
//TF import tests
|
||||||
//TFGraphTestAllSameDiff.class
|
//TFGraphTestAllSameDiff.class
|
||||||
//TFGraphTestAllLibnd4j.class
|
//TFGraphTestAllLibnd4j.class
|
||||||
})
|
})*/
|
||||||
//IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test"
|
//IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test"
|
||||||
// With it ignored here, the individual tests will run outside (i.e., separately/independently) of the suite in both "mvn test" and IntelliJ
|
// With it ignored here, the individual tests will run outside (i.e., separately/independently) of the suite in both "mvn test" and IntelliJ
|
||||||
@Disabled
|
@Disabled
|
||||||
|
@ -84,7 +82,7 @@ public class OpValidationSuite {
|
||||||
Nd4j.getRandom().setSeed(123);
|
Nd4j.getRandom().setSeed(123);
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterClass
|
@AfterEach
|
||||||
public static void afterClass() {
|
public static void afterClass() {
|
||||||
Nd4j.setDataType(initialType);
|
Nd4j.setDataType(initialType);
|
||||||
|
|
||||||
|
|
|
@ -145,9 +145,8 @@ public class TestOpMapping extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOpMappingCoverage() throws Exception {
|
public void testOpMappingCoverage() throws Exception {
|
||||||
Map<String, DifferentialFunction> opNameMapping = ImportClassMapping.getOpNameMapping();
|
Map<String, DifferentialFunction> opNameMapping = ImportClassMapping.getOpNameMapping();
|
||||||
Map<String, DifferentialFunction> tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions();
|
Map<String, DifferentialFunction> tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions();
|
||||||
|
@ -197,9 +196,8 @@ public class TestOpMapping extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOpsInNamespace(Nd4jBackend backend) throws Exception {
|
public void testOpsInNamespace(Nd4jBackend backend) throws Exception {
|
||||||
//Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't
|
//Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't
|
||||||
// want to add to a namespace for some reason)
|
// want to add to a namespace for some reason)
|
||||||
|
@ -361,7 +359,7 @@ public class TestOpMapping extends BaseNd4jTestWithBackends {
|
||||||
@Test
|
@Test
|
||||||
@Disabled
|
@Disabled
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void generateOpClassList(Nd4jBackend backend) throws Exception{
|
public void generateOpClassList(Nd4jBackend backend) throws Exception{
|
||||||
Reflections reflections = new Reflections("org.nd4j");
|
Reflections reflections = new Reflections("org.nd4j");
|
||||||
Set<Class<? extends DifferentialFunction>> subTypes = reflections.getSubTypesOf(DifferentialFunction.class);
|
Set<Class<? extends DifferentialFunction>> subTypes = reflections.getSubTypesOf(DifferentialFunction.class);
|
||||||
|
|
|
@ -55,9 +55,8 @@ public class TestSessions extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInferenceSessionBasic(Nd4jBackend backend) {
|
public void testInferenceSessionBasic(Nd4jBackend backend) {
|
||||||
//So far: trivial test to check execution order
|
//So far: trivial test to check execution order
|
||||||
|
|
||||||
|
@ -89,9 +88,8 @@ public class TestSessions extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInferenceSessionBasic2(Nd4jBackend backend) {
|
public void testInferenceSessionBasic2(Nd4jBackend backend) {
|
||||||
//So far: trivial test to check execution order
|
//So far: trivial test to check execution order
|
||||||
|
|
||||||
|
@ -127,9 +125,8 @@ public class TestSessions extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(dExp, outMap.get("d"));
|
assertEquals(dExp, outMap.get("d"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMergeSimple(Nd4jBackend backend) {
|
public void testMergeSimple(Nd4jBackend backend) {
|
||||||
//This isn't really a sensible graph, as merge op behaviour is undefined when multiple inputs are available...
|
//This isn't really a sensible graph, as merge op behaviour is undefined when multiple inputs are available...
|
||||||
|
|
||||||
|
@ -165,9 +162,8 @@ public class TestSessions extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSwitchSimple(Nd4jBackend backend) {
|
public void testSwitchSimple(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
|
@ -34,7 +34,6 @@ import org.nd4j.common.primitives.Pair;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertNotNull;
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestDependencyTracker extends BaseNd4jTestWithBackends {
|
public class TestDependencyTracker extends BaseNd4jTestWithBackends {
|
||||||
|
@ -45,9 +44,8 @@ public class TestDependencyTracker extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testSimple(Nd4jBackend backend){
|
public void testSimple(Nd4jBackend backend){
|
||||||
|
|
||||||
DependencyTracker<String,String> dt = new DependencyTracker<>();
|
DependencyTracker<String,String> dt = new DependencyTracker<>();
|
||||||
|
@ -94,9 +92,8 @@ public class TestDependencyTracker extends BaseNd4jTestWithBackends {
|
||||||
assertTrue(dt.isEmpty());
|
assertTrue(dt.isEmpty());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testSatisfiedBeforeAdd(Nd4jBackend backend){
|
public void testSatisfiedBeforeAdd(Nd4jBackend backend){
|
||||||
DependencyTracker<String,String> dt = new DependencyTracker<>();
|
DependencyTracker<String,String> dt = new DependencyTracker<>();
|
||||||
|
|
||||||
|
@ -135,9 +132,8 @@ public class TestDependencyTracker extends BaseNd4jTestWithBackends {
|
||||||
assertFalse(dt.hasNewAllSatisfied());
|
assertFalse(dt.hasNewAllSatisfied());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testMarkUnsatisfied(Nd4jBackend backend){
|
public void testMarkUnsatisfied(Nd4jBackend backend){
|
||||||
|
|
||||||
DependencyTracker<String,String> dt = new DependencyTracker<>();
|
DependencyTracker<String,String> dt = new DependencyTracker<>();
|
||||||
|
@ -169,9 +165,8 @@ public class TestDependencyTracker extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testIdentityDependencyTracker(){
|
public void testIdentityDependencyTracker(){
|
||||||
IdentityDependencyTracker<INDArray, String> dt = new IdentityDependencyTracker<>();
|
IdentityDependencyTracker<INDArray, String> dt = new IdentityDependencyTracker<>();
|
||||||
assertTrue(dt.isEmpty());
|
assertTrue(dt.isEmpty());
|
||||||
|
|
|
@ -41,9 +41,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
public class ActivationGradChecks extends BaseOpValidation {
|
public class ActivationGradChecks extends BaseOpValidation {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testActivationGradientCheck1(Nd4jBackend backend) {
|
public void testActivationGradientCheck1(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -61,9 +60,8 @@ public class ActivationGradChecks extends BaseOpValidation {
|
||||||
assertTrue(ok);
|
assertTrue(ok);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testActivationGradientCheck2(Nd4jBackend backend) {
|
public void testActivationGradientCheck2(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
|
@ -73,9 +73,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
return 90000L;
|
return 90000L;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testXwPlusB(Nd4jBackend backend) {
|
public void testXwPlusB(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -109,9 +108,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReluLayer(Nd4jBackend backend) {
|
public void testReluLayer(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -139,9 +137,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBiasAdd(Nd4jBackend backend) {
|
public void testBiasAdd(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -165,9 +162,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConv2d(Nd4jBackend backend) {
|
public void testConv2d(Nd4jBackend backend) {
|
||||||
//avg pool, batch norm, conv2d, max pool 2d, pooling2d, upsampling
|
//avg pool, batch norm, conv2d, max pool 2d, pooling2d, upsampling
|
||||||
//Tested elsewhere: deconv2d, depthwise2d, LRN, sconv2d
|
//Tested elsewhere: deconv2d, depthwise2d, LRN, sconv2d
|
||||||
|
@ -307,9 +303,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),failed.toString());
|
assertEquals(0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLrn2d(Nd4jBackend backend) {
|
public void testLrn2d(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -350,9 +345,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),failed.toString());
|
assertEquals(0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIm2Col(Nd4jBackend backend) {
|
public void testIm2Col(Nd4jBackend backend) {
|
||||||
//OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873
|
//OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/eclipse/deeplearning4j/issues/6873
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -391,9 +385,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOutputShape(Nd4jBackend backend) {
|
public void testOutputShape(Nd4jBackend backend) {
|
||||||
long[] inSize = {1, 8, 8, 3};
|
long[] inSize = {1, 8, 8, 3};
|
||||||
|
|
||||||
|
@ -443,9 +436,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAvgPool(Nd4jBackend backend) {
|
public void testAvgPool(Nd4jBackend backend) {
|
||||||
long[] inSize = {1, 8, 8, 3}; //NHWC
|
long[] inSize = {1, 8, 8, 3}; //NHWC
|
||||||
|
|
||||||
|
@ -488,9 +480,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
return new int[]{in[0], in[2], in[3], in[4], in[1]};
|
return new int[]{in[0], in[2], in[3], in[4], in[1]};
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConv3d(Nd4jBackend backend) {
|
public void testConv3d(Nd4jBackend backend) {
|
||||||
//Pooling3d, Conv3D, batch norm
|
//Pooling3d, Conv3D, batch norm
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -592,9 +583,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDepthWiseConv2dBasic(Nd4jBackend backend) {
|
public void testDepthWiseConv2dBasic(Nd4jBackend backend) {
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
int depthWise = 4;
|
int depthWise = 4;
|
||||||
|
@ -633,9 +623,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(new long[]{mb, depthWise * nIn, 27, 27}, outShape);
|
assertArrayEquals(new long[]{mb, depthWise * nIn, 27, 27}, outShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSeparableConv2dBasic(Nd4jBackend backend) {
|
public void testSeparableConv2dBasic(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
int nIn = 2;
|
int nIn = 2;
|
||||||
|
@ -691,9 +680,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDeconv2dBasic(Nd4jBackend backend) {
|
public void testDeconv2dBasic(Nd4jBackend backend) {
|
||||||
int nIn = 2;
|
int nIn = 2;
|
||||||
int nOut = 3;
|
int nOut = 3;
|
||||||
|
@ -737,9 +725,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConv2dBasic(Nd4jBackend backend) {
|
public void testConv2dBasic(Nd4jBackend backend) {
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
int nOut = 4;
|
int nOut = 4;
|
||||||
|
@ -780,9 +767,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
// sd.execBackwards(); // TODO: test failing here
|
// sd.execBackwards(); // TODO: test failing here
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMaxPoolingArgMax(Nd4jBackend backend) {
|
public void testMaxPoolingArgMax(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
|
@ -811,9 +797,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(inArr.shape(), results[1].eval().shape());
|
assertArrayEquals(inArr.shape(), results[1].eval().shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMaxPooling2dBasic(Nd4jBackend backend) {
|
public void testMaxPooling2dBasic(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
|
@ -871,9 +856,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
return max;
|
return max;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAvgPooling2dBasic(Nd4jBackend backend) {
|
public void testAvgPooling2dBasic(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
|
@ -922,9 +906,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAvgPooling3dBasic(Nd4jBackend backend) {
|
public void testAvgPooling3dBasic(Nd4jBackend backend) {
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
int kH = 2;
|
int kH = 2;
|
||||||
|
@ -961,9 +944,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMaxPooling3dBasic(Nd4jBackend backend) {
|
public void testMaxPooling3dBasic(Nd4jBackend backend) {
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
int kH = 2;
|
int kH = 2;
|
||||||
|
@ -1001,9 +983,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConv1dBasic(Nd4jBackend backend) {
|
public void testConv1dBasic(Nd4jBackend backend) {
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
int nOut = 4;
|
int nOut = 4;
|
||||||
|
@ -1038,9 +1019,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConv1dCausal(Nd4jBackend backend) {
|
public void testConv1dCausal(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
|
@ -1089,9 +1069,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConv1dForward(Nd4jBackend backend) {
|
public void testConv1dForward(Nd4jBackend backend) {
|
||||||
int nIn = 2;
|
int nIn = 2;
|
||||||
int nOut = 1;
|
int nOut = 1;
|
||||||
|
@ -1134,9 +1113,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConv3dBasic(Nd4jBackend backend) {
|
public void testConv3dBasic(Nd4jBackend backend) {
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
int nOut = 4;
|
int nOut = 4;
|
||||||
|
@ -1182,9 +1160,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDeConv3dBasic(Nd4jBackend backend) {
|
public void testDeConv3dBasic(Nd4jBackend backend) {
|
||||||
int nIn = 4;
|
int nIn = 4;
|
||||||
int nOut = 3;
|
int nOut = 3;
|
||||||
|
@ -1229,9 +1206,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLayerNorm(Nd4jBackend backend) {
|
public void testLayerNorm(Nd4jBackend backend) {
|
||||||
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
||||||
final INDArray standardized = random.ulike();
|
final INDArray standardized = random.ulike();
|
||||||
|
@ -1256,9 +1232,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLayerNorm4d(Nd4jBackend backend) {
|
public void testLayerNorm4d(Nd4jBackend backend) {
|
||||||
int mb = 3;
|
int mb = 3;
|
||||||
int ch = 4;
|
int ch = 4;
|
||||||
|
@ -1290,9 +1265,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLayerNormOP(Nd4jBackend backend) {
|
public void testLayerNormOP(Nd4jBackend backend) {
|
||||||
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
||||||
final INDArray standardized = random.ulike();
|
final INDArray standardized = random.ulike();
|
||||||
|
@ -1308,9 +1282,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertEquals(res, output);
|
assertEquals(res, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLayerNormNoBias(Nd4jBackend backend) {
|
public void testLayerNormNoBias(Nd4jBackend backend) {
|
||||||
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
||||||
final INDArray standardized = random.ulike();
|
final INDArray standardized = random.ulike();
|
||||||
|
@ -1333,9 +1306,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err, err);
|
assertNull(err, err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLayerNormOPNoBias(Nd4jBackend backend) {
|
public void testLayerNormOPNoBias(Nd4jBackend backend) {
|
||||||
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
||||||
final INDArray standardized = random.ulike();
|
final INDArray standardized = random.ulike();
|
||||||
|
@ -1350,9 +1322,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertEquals(res, output);
|
assertEquals(res, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLayerNormNoDeviation(Nd4jBackend backend) {
|
public void testLayerNormNoDeviation(Nd4jBackend backend) {
|
||||||
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
|
@ -1467,9 +1438,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLayerNormMixedOrders(Nd4jBackend backend) {
|
public void testLayerNormMixedOrders(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');
|
INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');
|
||||||
|
@ -1516,9 +1486,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertEquals(outCC, outFC); //Fails here
|
assertEquals(outCC, outFC); //Fails here
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBiasAdd_nchw_nhwc(Nd4jBackend backend) {
|
public void testBiasAdd_nchw_nhwc(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -1549,9 +1518,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDepthwiseConv2D(){
|
public void testDepthwiseConv2D(){
|
||||||
|
|
||||||
int bS = 10;
|
int bS = 10;
|
||||||
|
@ -1589,9 +1557,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void LSTMLayerTestCase1(Nd4jBackend backend) {
|
public void LSTMLayerTestCase1(Nd4jBackend backend) {
|
||||||
|
|
||||||
int bS = 5;
|
int bS = 5;
|
||||||
|
@ -1666,9 +1633,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void LSTMLayerTestCase2(Nd4jBackend backend) {
|
public void LSTMLayerTestCase2(Nd4jBackend backend) {
|
||||||
int bS = 5;
|
int bS = 5;
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
|
@ -1726,9 +1692,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void LSTMLayerTestCase3(Nd4jBackend backend) {
|
public void LSTMLayerTestCase3(Nd4jBackend backend) {
|
||||||
int bS = 5;
|
int bS = 5;
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
|
@ -1789,9 +1754,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void GRUTestCase(Nd4jBackend backend) {
|
public void GRUTestCase(Nd4jBackend backend) {
|
||||||
int bS = 5;
|
int bS = 5;
|
||||||
int nIn = 4;
|
int nIn = 4;
|
||||||
|
|
|
@ -55,9 +55,8 @@ public class LossOpValidation extends BaseOpValidation {
|
||||||
// All tested Loss Ops have backprop at the moment 2019/01/30
|
// All tested Loss Ops have backprop at the moment 2019/01/30
|
||||||
public static final Set<String> NO_BP_YET = new HashSet<>();
|
public static final Set<String> NO_BP_YET = new HashSet<>();
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLoss2d(Nd4jBackend backend) {
|
public void testLoss2d(Nd4jBackend backend) {
|
||||||
final List<String> oneDimensionalOutputFns = Arrays.asList("cosine", "mpwse", "softmaxxent", "softmaxxent_smooth", "mpwse", "sparsesoftmax");
|
final List<String> oneDimensionalOutputFns = Arrays.asList("cosine", "mpwse", "softmaxxent", "softmaxxent_smooth", "mpwse", "sparsesoftmax");
|
||||||
|
|
||||||
|
@ -369,9 +368,8 @@ public class LossOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCosineDistance(){
|
public void testCosineDistance(){
|
||||||
INDArray arr = Nd4j.create(new double[][]{{-0.3, -0.2, -0.1}, {0, 0.1, 0.2}});
|
INDArray arr = Nd4j.create(new double[][]{{-0.3, -0.2, -0.1}, {0, 0.1, 0.2}});
|
||||||
INDArray label = Nd4j.create(new double[][]{{1.0, 2.0, 3.0}, {-1.0, 2.0, 1.0}});
|
INDArray label = Nd4j.create(new double[][]{{1.0, 2.0, 3.0}, {-1.0, 2.0, 1.0}});
|
||||||
|
@ -389,9 +387,8 @@ public class LossOpValidation extends BaseOpValidation {
|
||||||
assertEquals(exp, out);
|
assertEquals(exp, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testL2Loss(){
|
public void testL2Loss(){
|
||||||
|
|
||||||
for( int rank=0; rank<=3; rank++ ){
|
for( int rank=0; rank<=3; rank++ ){
|
||||||
|
@ -433,9 +430,8 @@ public class LossOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNonZeroResult(Nd4jBackend backend) {
|
public void testNonZeroResult(Nd4jBackend backend) {
|
||||||
INDArray predictions = Nd4j.rand(DataType.DOUBLE, 10, 5);
|
INDArray predictions = Nd4j.rand(DataType.DOUBLE, 10, 5);
|
||||||
INDArray w = Nd4j.scalar(1.0);
|
INDArray w = Nd4j.scalar(1.0);
|
||||||
|
@ -493,9 +489,8 @@ public class LossOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void TestStdLossMixedDataType(){
|
public void TestStdLossMixedDataType(){
|
||||||
// Default Data Type in this test suite is Double.
|
// Default Data Type in this test suite is Double.
|
||||||
// This test used to throw an Exception that we have mixed data types.
|
// This test used to throw an Exception that we have mixed data types.
|
||||||
|
|
|
@ -74,17 +74,17 @@ import org.nd4j.common.util.ArrayUtil;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
import static org.junit.Assume.assumeNotNull;
|
import static org.junit.jupiter.api.Assumptions.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class MiscOpValidation extends BaseOpValidation {
|
public class MiscOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGradientAutoBroadcast1(Nd4jBackend backend) {
|
public void testGradientAutoBroadcast1(Nd4jBackend backend) {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -171,9 +171,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),"Failed: " + failed);
|
assertEquals(0, failed.size(),"Failed: " + failed);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGradientAutoBroadcast2(Nd4jBackend backend) {
|
public void testGradientAutoBroadcast2(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -262,9 +261,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),"Failed: " + failed);
|
assertEquals(0, failed.size(),"Failed: " + failed);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGradientAutoBroadcast3(Nd4jBackend backend) {
|
public void testGradientAutoBroadcast3(Nd4jBackend backend) {
|
||||||
//These tests: output size > input sizes
|
//These tests: output size > input sizes
|
||||||
|
|
||||||
|
@ -372,9 +370,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
return Long.MAX_VALUE;
|
return Long.MAX_VALUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScatterOpGradients(Nd4jBackend backend) {
|
public void testScatterOpGradients(Nd4jBackend backend) {
|
||||||
List<String> failed = new ArrayList<>();
|
List<String> failed = new ArrayList<>();
|
||||||
|
|
||||||
|
@ -476,9 +473,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),failed.toString());
|
assertEquals(0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScatterUpdate(){
|
public void testScatterUpdate(){
|
||||||
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3);
|
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3);
|
||||||
INDArray updates = Nd4j.create(new float[][]{
|
INDArray updates = Nd4j.create(new float[][]{
|
||||||
|
@ -499,9 +495,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(exp, out);
|
assertEquals(exp, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGatherGradient(Nd4jBackend backend) {
|
public void testGatherGradient(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -552,9 +547,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTrace(){
|
public void testTrace(){
|
||||||
//TODO need to work out how to handle shape_op for scalars...
|
//TODO need to work out how to handle shape_op for scalars...
|
||||||
//OpValidationSuite.ignoreFailing();
|
//OpValidationSuite.ignoreFailing();
|
||||||
|
@ -579,9 +573,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTensorGradTensorMmul(Nd4jBackend backend) {
|
public void testTensorGradTensorMmul(Nd4jBackend backend) {
|
||||||
OpValidationSuite.ignoreFailing();
|
OpValidationSuite.ignoreFailing();
|
||||||
|
|
||||||
|
@ -603,9 +596,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMulGradient(Nd4jBackend backend) {
|
public void testMulGradient(Nd4jBackend backend) {
|
||||||
INDArray arr1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray arr1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
INDArray arr2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray arr2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
|
@ -670,9 +662,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMmulGradientManual(Nd4jBackend backend) {
|
public void testMmulGradientManual(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
|
@ -689,14 +680,14 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
}, inputs);
|
}, inputs);
|
||||||
|
|
||||||
|
|
||||||
assumeNotNull(sameDiff.getFunction("mmulGradient").getFunction("grad"));
|
assertNotNull(sameDiff.getFunction("mmulGradient").getFunction("grad"));
|
||||||
assumeNotNull(sameDiff.getFunction("mmulGradient").grad("x"));
|
assertNotNull(sameDiff.getFunction("mmulGradient").grad("x"));
|
||||||
assumeNotNull(sameDiff.getFunction("mmulGradient").grad("y"));
|
assertNotNull(sameDiff.getFunction("mmulGradient").grad("y"));
|
||||||
|
|
||||||
SDVariable gradWrtX = sameDiff.getFunction("mmulGradient").grad("x");
|
SDVariable gradWrtX = sameDiff.getFunction("mmulGradient").grad("x");
|
||||||
SDVariable gradWrtY = sameDiff.getFunction("mmulGradient").grad("y");
|
SDVariable gradWrtY = sameDiff.getFunction("mmulGradient").grad("y");
|
||||||
assumeNotNull(gradWrtX.getArr());
|
assertNotNull(gradWrtX.getArr());
|
||||||
assumeNotNull(gradWrtY.getArr());
|
assertNotNull(gradWrtY.getArr());
|
||||||
|
|
||||||
|
|
||||||
INDArray xGradAssertion = Nd4j.create(new double[][]{
|
INDArray xGradAssertion = Nd4j.create(new double[][]{
|
||||||
|
@ -713,9 +704,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(yGradAssertion, gradWrtY.getArr());
|
assertEquals(yGradAssertion, gradWrtY.getArr());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMmulGradients(){
|
public void testMmulGradients(){
|
||||||
int[] aShape = new int[]{2,3};
|
int[] aShape = new int[]{2,3};
|
||||||
int[] bShape = new int[]{3,4};
|
int[] bShape = new int[]{3,4};
|
||||||
|
@ -766,9 +756,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
return new int[]{orig[1], orig[0]};
|
return new int[]{orig[1], orig[0]};
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBatchMmulBasic(Nd4jBackend backend) {
|
public void testBatchMmulBasic(Nd4jBackend backend) {
|
||||||
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6873
|
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6873
|
||||||
int M = 5;
|
int M = 5;
|
||||||
|
@ -793,9 +782,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMmulWithTranspose(Nd4jBackend backend) {
|
public void testMmulWithTranspose(Nd4jBackend backend) {
|
||||||
|
|
||||||
//Here: [x,3]^T * [x,4] = [3,4]
|
//Here: [x,3]^T * [x,4] = [3,4]
|
||||||
|
@ -832,9 +820,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMmulOutputSizeCalculation(){
|
public void testMmulOutputSizeCalculation(){
|
||||||
//[3,2] x [2,4] with result transpose: output shape [4,3]
|
//[3,2] x [2,4] with result transpose: output shape [4,3]
|
||||||
INDArray a = Nd4j.create(3,2);
|
INDArray a = Nd4j.create(3,2);
|
||||||
|
@ -866,9 +853,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testFillOp(){
|
public void testFillOp(){
|
||||||
|
|
||||||
INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT);
|
INDArray ia = Nd4j.createFromArray(new double[]{2,2}).castTo(DataType.INT);
|
||||||
|
@ -882,9 +868,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testClipByNorm(){
|
public void testClipByNorm(){
|
||||||
//Expected: if array.norm2(1) is less than 1.0, not modified
|
//Expected: if array.norm2(1) is less than 1.0, not modified
|
||||||
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
|
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
|
||||||
|
@ -916,9 +901,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(exp, norm2_1b);
|
assertEquals(exp, norm2_1b);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testClipByNorm2(){
|
public void testClipByNorm2(){
|
||||||
//Expected: if array.norm2(1) is less than 1.0, not modified
|
//Expected: if array.norm2(1) is less than 1.0, not modified
|
||||||
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
|
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
|
||||||
|
@ -961,9 +945,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testClipByNorm1(){
|
public void testClipByNorm1(){
|
||||||
//Expected: if array.norm2(1) is less than 1.0, not modified
|
//Expected: if array.norm2(1) is less than 1.0, not modified
|
||||||
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
|
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
|
||||||
|
@ -1003,9 +986,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testClipByNorm0(){
|
public void testClipByNorm0(){
|
||||||
//Expected: if array.norm2(0) is less than 1.0, not modified
|
//Expected: if array.norm2(0) is less than 1.0, not modified
|
||||||
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
|
//Otherwise: array.tad(x,1) = array.tad(x,1) * 1.0 / array.tad(x,1).norm2()
|
||||||
|
@ -1034,9 +1016,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(OpValidation.validate(op));
|
assertNull(OpValidation.validate(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCumSum(){
|
public void testCumSum(){
|
||||||
|
|
||||||
List<String> failing = new ArrayList<>();
|
List<String> failing = new ArrayList<>();
|
||||||
|
@ -1101,9 +1082,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCumProd(){
|
public void testCumProd(){
|
||||||
List<String> failing = new ArrayList<>();
|
List<String> failing = new ArrayList<>();
|
||||||
|
|
||||||
|
@ -1171,9 +1151,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failing.size(),failing.toString());
|
assertEquals(0, failing.size(),failing.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOneHot1(){
|
public void testOneHot1(){
|
||||||
List<String> failed = new ArrayList<>();
|
List<String> failed = new ArrayList<>();
|
||||||
|
|
||||||
|
@ -1203,9 +1182,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals( 0, failed.size(),failed.toString());
|
assertEquals( 0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOneHotOp(){
|
public void testOneHotOp(){
|
||||||
//https://www.tensorflow.org/api_docs/python/tf/one_hot
|
//https://www.tensorflow.org/api_docs/python/tf/one_hot
|
||||||
//https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp
|
//https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp
|
||||||
|
@ -1219,9 +1197,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOneHot2(Nd4jBackend backend) {
|
public void testOneHot2(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
|
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
|
||||||
|
@ -1241,9 +1218,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOneHot4(Nd4jBackend backend) {
|
public void testOneHot4(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
|
INDArray indicesArr = Nd4j.createFromArray(0, 2, -1, 1);
|
||||||
|
@ -1263,9 +1239,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOneHot3(Nd4jBackend backend) {
|
public void testOneHot3(Nd4jBackend backend) {
|
||||||
//https://github.com/deeplearning4j/deeplearning4j/issues/6872
|
//https://github.com/deeplearning4j/deeplearning4j/issues/6872
|
||||||
|
|
||||||
|
@ -1300,9 +1275,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLinspace(){
|
public void testLinspace(){
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable out = sd.linspace("linspace", DataType.DOUBLE, 1,10,10);
|
SDVariable out = sd.linspace("linspace", DataType.DOUBLE, 1,10,10);
|
||||||
|
@ -1315,9 +1289,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLinspace2(){
|
public void testLinspace2(){
|
||||||
OpValidationSuite.ignoreFailing(); //TODO 2019/01/18
|
OpValidationSuite.ignoreFailing(); //TODO 2019/01/18
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -1331,9 +1304,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testShapeFn(Nd4jBackend backend) {
|
public void testShapeFn(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray in = Nd4j.create(new long[]{1, 2});
|
INDArray in = Nd4j.create(new long[]{1, 2});
|
||||||
|
@ -1347,9 +1319,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(new long[]{2}, shapes.get(0).getShape());
|
assertArrayEquals(new long[]{2}, shapes.get(0).getShape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testShapeFn2(Nd4jBackend backend) {
|
public void testShapeFn2(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray i = Nd4j.create(1,3);
|
INDArray i = Nd4j.create(1,3);
|
||||||
|
@ -1362,9 +1333,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMergeRank1(){
|
public void testMergeRank1(){
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5));
|
SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5));
|
||||||
|
@ -1382,9 +1352,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(1, inGrad.rank());
|
assertEquals(1, inGrad.rank());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDiagPart(Nd4jBackend backend) {
|
public void testDiagPart(Nd4jBackend backend) {
|
||||||
INDArray i = Nd4j.create(5,5);
|
INDArray i = Nd4j.create(5,5);
|
||||||
|
|
||||||
|
@ -1396,9 +1365,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(1, out.rank());
|
assertEquals(1, out.rank());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDiagShapeFn(Nd4jBackend backend) {
|
public void testDiagShapeFn(Nd4jBackend backend) {
|
||||||
INDArray i = Nd4j.create(5,5);
|
INDArray i = Nd4j.create(5,5);
|
||||||
|
|
||||||
|
@ -1411,9 +1379,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testZerosOnesLike(){
|
public void testZerosOnesLike(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -1455,9 +1422,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),failed.toString());
|
assertEquals(0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testZerosLikeOp(){
|
public void testZerosLikeOp(){
|
||||||
|
|
||||||
INDArray arr = Nd4j.scalar(DataType.DOUBLE, 1.0);
|
INDArray arr = Nd4j.scalar(DataType.DOUBLE, 1.0);
|
||||||
|
@ -1472,9 +1438,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConfusionMatrix(){
|
public void testConfusionMatrix(){
|
||||||
DataType dt = DataType.DOUBLE;
|
DataType dt = DataType.DOUBLE;
|
||||||
|
|
||||||
|
@ -1510,9 +1475,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIsNonDecreasingIsStrictlyIncr(){
|
public void testIsNonDecreasingIsStrictlyIncr(){
|
||||||
List<long[]> shapes = Arrays.asList(null, new long[]{12}, new long[]{1,12}, new long[]{3,4}, new long[]{2,2,3});
|
List<long[]> shapes = Arrays.asList(null, new long[]{12}, new long[]{1,12}, new long[]{3,4}, new long[]{2,2,3});
|
||||||
|
|
||||||
|
@ -1575,9 +1539,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals( 0, failed.size(),failed.toString());
|
assertEquals( 0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testExtractImagePatches(){
|
public void testExtractImagePatches(){
|
||||||
/*
|
/*
|
||||||
tf.reset_default_graph()
|
tf.reset_default_graph()
|
||||||
|
@ -1624,9 +1587,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(exp, out);
|
assertEquals(exp, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSegmentProdBpSimple(){
|
public void testSegmentProdBpSimple(){
|
||||||
|
|
||||||
INDArray segmentIdxs = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT);
|
INDArray segmentIdxs = Nd4j.create(new double[]{0,0,0,1,2,2,3,3}, new long[]{8}).castTo(DataType.INT);
|
||||||
|
@ -1646,9 +1608,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
Nd4j.getExecutioner().exec(op);
|
Nd4j.getExecutioner().exec(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMmulRank4() throws Exception {
|
public void testMmulRank4() throws Exception {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -1683,9 +1644,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(outExp, out);
|
assertEquals(outExp, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMmulRank4_simple(){
|
public void testMmulRank4_simple(){
|
||||||
|
|
||||||
INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
|
INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
|
||||||
|
@ -1711,9 +1671,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(exp, out);
|
assertEquals(exp, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNthElementRank1(){
|
public void testNthElementRank1(){
|
||||||
INDArray in = Nd4j.createFromArray(new double[]{0,1,2,3,4,5,6,7,8,9});
|
INDArray in = Nd4j.createFromArray(new double[]{0,1,2,3,4,5,6,7,8,9});
|
||||||
INDArray n = Nd4j.scalar(0);
|
INDArray n = Nd4j.scalar(0);
|
||||||
|
@ -1735,9 +1694,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0.0, out.getDouble(0), 1e-5);
|
assertEquals(0.0, out.getDouble(0), 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTensorMmulShape(){
|
public void testTensorMmulShape(){
|
||||||
INDArray a = Nd4j.create(new double[]{2}).reshape(1);
|
INDArray a = Nd4j.create(new double[]{2}).reshape(1);
|
||||||
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
|
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
|
||||||
|
@ -1755,9 +1713,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(new long[]{2,2}, l.get(0).getShape()); //Returning [1,2,2]
|
assertArrayEquals(new long[]{2,2}, l.get(0).getShape()); //Returning [1,2,2]
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTensorMmulShape2(){
|
public void testTensorMmulShape2(){
|
||||||
INDArray a = Nd4j.create(new double[]{2}).reshape(1);
|
INDArray a = Nd4j.create(new double[]{2}).reshape(1);
|
||||||
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
|
INDArray b = Nd4j.create(new double[]{1, 2, 3, 4}).reshape(2, 1, 2);
|
||||||
|
@ -1765,9 +1722,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(new long[]{2,2}, c.shape());
|
assertArrayEquals(new long[]{2,2}, c.shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStopGradient(){
|
public void testStopGradient(){
|
||||||
|
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -1786,9 +1742,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(Nd4j.zeros(DataType.DOUBLE, 3, 4), wArr);
|
assertEquals(Nd4j.zeros(DataType.DOUBLE, 3, 4), wArr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckNumerics(){
|
public void testCheckNumerics(){
|
||||||
OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/7927
|
OpValidationSuite.ignoreFailing(); //https://github.com/eclipse/deeplearning4j/issues/7927
|
||||||
|
|
||||||
|
@ -1831,9 +1786,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
sd.outputAll(Collections.singletonMap("in", in));
|
sd.outputAll(Collections.singletonMap("in", in));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckNumerics2(Nd4jBackend backend) {
|
public void testCheckNumerics2(Nd4jBackend backend) {
|
||||||
INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4);
|
INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4);
|
||||||
INDArray msg = Nd4j.scalar("My error message!");
|
INDArray msg = Nd4j.scalar("My error message!");
|
||||||
|
@ -1846,9 +1800,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
Nd4j.getExecutioner().exec(op);
|
Nd4j.getExecutioner().exec(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testHistogramFixedWidth(){
|
public void testHistogramFixedWidth(){
|
||||||
//Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf]
|
//Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf]
|
||||||
INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9);
|
INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9);
|
||||||
|
@ -1866,9 +1819,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(exp, out);
|
assertEquals(exp, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDynamicPartition(){
|
public void testDynamicPartition(){
|
||||||
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
|
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
|
||||||
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
|
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
|
||||||
|
@ -1886,9 +1838,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(exp2, out[2]);
|
assertEquals(exp2, out[2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testListDiff(){
|
public void testListDiff(){
|
||||||
INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
|
INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
|
||||||
INDArray y = Nd4j.createFromArray(3, 1);
|
INDArray y = Nd4j.createFromArray(3, 1);
|
||||||
|
@ -1907,9 +1858,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertEquals(exp, outIdx); //Indices of the values in x not in y
|
assertEquals(exp, outIdx); //Indices of the values in x not in y
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDivideNoNan(Nd4jBackend backend) {
|
public void testDivideNoNan(Nd4jBackend backend) {
|
||||||
OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff()
|
OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff()
|
||||||
|
|
||||||
|
@ -1933,9 +1883,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDigamma(Nd4jBackend backend) {
|
public void testDigamma(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||||
|
@ -1950,9 +1899,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testFlatten(Nd4jBackend backend) {
|
public void testFlatten(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
@ -1974,9 +1922,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testFusedBatchNorm(Nd4jBackend backend) {
|
public void testFusedBatchNorm(Nd4jBackend backend) {
|
||||||
OpValidationSuite.ignoreFailing();
|
OpValidationSuite.ignoreFailing();
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
@ -2021,9 +1968,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIgamma(Nd4jBackend backend) {
|
public void testIgamma(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||||
|
@ -2039,9 +1985,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIgammaC(Nd4jBackend backend) {
|
public void testIgammaC(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||||
|
@ -2058,9 +2003,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLgamma(Nd4jBackend backend) {
|
public void testLgamma(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
@ -2085,9 +2029,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLu(Nd4jBackend backend) {
|
public void testLu(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
@ -2118,9 +2061,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMatrixBandPart(Nd4jBackend backend) {
|
public void testMatrixBandPart(Nd4jBackend backend) {
|
||||||
OpValidationSuite.ignoreFailing();
|
OpValidationSuite.ignoreFailing();
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
@ -2150,9 +2092,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPolygamma(Nd4jBackend backend) {
|
public void testPolygamma(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||||
|
@ -2168,9 +2109,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTriangularSolve(Nd4jBackend backend) {
|
public void testTriangularSolve(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray a = Nd4j.createFromArray(new float[]{
|
INDArray a = Nd4j.createFromArray(new float[]{
|
||||||
|
@ -2194,9 +2134,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBiasAdd(Nd4jBackend backend) {
|
public void testBiasAdd(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
@ -2225,9 +2164,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBiasAddGrad(Nd4jBackend backend) {
|
public void testBiasAddGrad(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
@ -2247,9 +2185,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRoll(Nd4jBackend backend) {
|
public void testRoll(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
|
INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
|
||||||
|
@ -2269,9 +2206,8 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSeqMask(){
|
public void testSeqMask(){
|
||||||
INDArray arr = Nd4j.createFromArray(1,2,3);
|
INDArray arr = Nd4j.createFromArray(1,2,3);
|
||||||
INDArray maxLen = Nd4j.scalar(4);
|
INDArray maxLen = Nd4j.scalar(4);
|
||||||
|
|
|
@ -54,9 +54,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
public class RandomOpValidation extends BaseOpValidation {
|
public class RandomOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRandomOpsSDVarShape(Nd4jBackend backend) {
|
public void testRandomOpsSDVarShape(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
List<String> failed = new ArrayList<>();
|
List<String> failed = new ArrayList<>();
|
||||||
|
@ -157,9 +156,8 @@ public class RandomOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),failed.toString());
|
assertEquals(0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRandomOpsLongShape(Nd4jBackend backend) {
|
public void testRandomOpsLongShape(Nd4jBackend backend) {
|
||||||
List<String> failed = new ArrayList<>();
|
List<String> failed = new ArrayList<>();
|
||||||
|
|
||||||
|
@ -285,9 +283,8 @@ public class RandomOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),failed.toString());
|
assertEquals(0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRandomBinomial(){
|
public void testRandomBinomial(){
|
||||||
|
|
||||||
INDArray z = Nd4j.create(new long[]{10});
|
INDArray z = Nd4j.create(new long[]{10});
|
||||||
|
@ -297,9 +294,8 @@ public class RandomOpValidation extends BaseOpValidation {
|
||||||
System.out.println(z);
|
System.out.println(z);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUniformRankSimple(Nd4jBackend backend) {
|
public void testUniformRankSimple(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray arr = Nd4j.createFromArray(new double[]{100.0});
|
INDArray arr = Nd4j.createFromArray(new double[]{100.0});
|
||||||
|
@ -331,9 +327,8 @@ public class RandomOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRandomExponential(Nd4jBackend backend) {
|
public void testRandomExponential(Nd4jBackend backend) {
|
||||||
long length = 1_000_000;
|
long length = 1_000_000;
|
||||||
INDArray shape = Nd4j.createFromArray(new double[]{length});
|
INDArray shape = Nd4j.createFromArray(new double[]{length});
|
||||||
|
@ -355,9 +350,8 @@ public class RandomOpValidation extends BaseOpValidation {
|
||||||
assertEquals( expStd, std, 0.1,"std");
|
assertEquals( expStd, std, 0.1,"std");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRange(){
|
public void testRange(){
|
||||||
//Technically deterministic, not random...
|
//Technically deterministic, not random...
|
||||||
|
|
||||||
|
@ -390,9 +384,8 @@ public class RandomOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAllEmptyReduce(){
|
public void testAllEmptyReduce(){
|
||||||
INDArray x = Nd4j.createFromArray(true, true, true);
|
INDArray x = Nd4j.createFromArray(true, true, true);
|
||||||
All all = new All(x);
|
All all = new All(x);
|
||||||
|
@ -401,9 +394,8 @@ public class RandomOpValidation extends BaseOpValidation {
|
||||||
assertEquals(x, out);
|
assertEquals(x, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUniformDtype(){
|
public void testUniformDtype(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
|
for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){
|
||||||
|
@ -431,9 +423,8 @@ public class RandomOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRandomExponential2(){
|
public void testRandomExponential2(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
DynamicCustomOp op = DynamicCustomOp.builder("random_exponential")
|
DynamicCustomOp op = DynamicCustomOp.builder("random_exponential")
|
||||||
|
|
|
@ -75,9 +75,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReduceSumBP(Nd4jBackend backend) {
|
public void testReduceSumBP(Nd4jBackend backend) {
|
||||||
//Full array reduction
|
//Full array reduction
|
||||||
|
|
||||||
|
@ -103,9 +102,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReduceSumAlongDim0BP(Nd4jBackend backend) {
|
public void testReduceSumAlongDim0BP(Nd4jBackend backend) {
|
||||||
//Reduction along dimension
|
//Reduction along dimension
|
||||||
//Inputs/outputs as before - but note that the output is no longer a scalar
|
//Inputs/outputs as before - but note that the output is no longer a scalar
|
||||||
|
@ -131,9 +129,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReduceSumAlongDim1BP(Nd4jBackend backend) {
|
public void testReduceSumAlongDim1BP(Nd4jBackend backend) {
|
||||||
//Reduction along dimension
|
//Reduction along dimension
|
||||||
//Inputs/outputs as before - but note that the output is no longer a scalar
|
//Inputs/outputs as before - but note that the output is no longer a scalar
|
||||||
|
@ -161,9 +158,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMeanBP(Nd4jBackend backend) {
|
public void testMeanBP(Nd4jBackend backend) {
|
||||||
|
|
||||||
//dL/dIn_i = dL/dOut * dOut/dIn_i = dL/dOut * (1/N * sum_j (in_j))
|
//dL/dIn_i = dL/dOut * dOut/dIn_i = dL/dOut * (1/N * sum_j (in_j))
|
||||||
|
@ -194,9 +190,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMeanBP_Rank1(Nd4jBackend backend) {
|
public void testMeanBP_Rank1(Nd4jBackend backend) {
|
||||||
INDArray dLdOut = Nd4j.scalar(0.5);
|
INDArray dLdOut = Nd4j.scalar(0.5);
|
||||||
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
|
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
|
||||||
|
@ -209,9 +204,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMeanAlongDim0BP(Nd4jBackend backend) {
|
public void testMeanAlongDim0BP(Nd4jBackend backend) {
|
||||||
//Reduction along dimension
|
//Reduction along dimension
|
||||||
//Inputs/outputs as before - but note that the output is no longer a scalar
|
//Inputs/outputs as before - but note that the output is no longer a scalar
|
||||||
|
@ -239,9 +233,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMeanAlongDim1BP(Nd4jBackend backend) {
|
public void testMeanAlongDim1BP(Nd4jBackend backend) {
|
||||||
//Reduction along dimension
|
//Reduction along dimension
|
||||||
//Inputs/outputs as before - but note that the output is no longer a scalar
|
//Inputs/outputs as before - but note that the output is no longer a scalar
|
||||||
|
@ -269,9 +262,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMinBP(Nd4jBackend backend) {
|
public void testMinBP(Nd4jBackend backend) {
|
||||||
//Full array min reduction
|
//Full array min reduction
|
||||||
|
|
||||||
|
@ -310,9 +302,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMinAlongDimensionBP(Nd4jBackend backend) {
|
public void testMinAlongDimensionBP(Nd4jBackend backend) {
|
||||||
//Full array min reduction
|
//Full array min reduction
|
||||||
|
|
||||||
|
@ -355,9 +346,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMaxBP(Nd4jBackend backend) {
|
public void testMaxBP(Nd4jBackend backend) {
|
||||||
//Full array max reduction
|
//Full array max reduction
|
||||||
|
|
||||||
|
@ -387,9 +377,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMaxAlongDimensionBP(Nd4jBackend backend) {
|
public void testMaxAlongDimensionBP(Nd4jBackend backend) {
|
||||||
//Full array min reduction
|
//Full array min reduction
|
||||||
|
|
||||||
|
@ -432,9 +421,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testProdBP(Nd4jBackend backend) {
|
public void testProdBP(Nd4jBackend backend) {
|
||||||
//Full array product reduction
|
//Full array product reduction
|
||||||
|
|
||||||
|
@ -463,9 +451,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testProdAlongDimensionBP(Nd4jBackend backend) {
|
public void testProdAlongDimensionBP(Nd4jBackend backend) {
|
||||||
//dL/dIn_i = dL/dOut * dOut/dIn_i
|
//dL/dIn_i = dL/dOut * dOut/dIn_i
|
||||||
// = dL/dOut * d(prod(in))/dIn_i
|
// = dL/dOut * d(prod(in))/dIn_i
|
||||||
|
@ -521,9 +508,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStdevBP(Nd4jBackend backend) {
|
public void testStdevBP(Nd4jBackend backend) {
|
||||||
//If out = stdev(in) then:
|
//If out = stdev(in) then:
|
||||||
//dL/dIn = dL/dOut * dOut/dIn
|
//dL/dIn = dL/dOut * dOut/dIn
|
||||||
|
@ -559,9 +545,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStdevBP_Rank1(Nd4jBackend backend) {
|
public void testStdevBP_Rank1(Nd4jBackend backend) {
|
||||||
INDArray dLdOut = Nd4j.scalar(0.5);
|
INDArray dLdOut = Nd4j.scalar(0.5);
|
||||||
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
|
INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3});
|
||||||
|
@ -582,9 +567,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStdevAlongDimensionBP(Nd4jBackend backend) {
|
public void testStdevAlongDimensionBP(Nd4jBackend backend) {
|
||||||
//If out = stdev(in) then:
|
//If out = stdev(in) then:
|
||||||
//dL/dIn = dL/dOut * dOut/dIn
|
//dL/dIn = dL/dOut * dOut/dIn
|
||||||
|
@ -629,9 +613,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testVarianceBP(Nd4jBackend backend) {
|
public void testVarianceBP(Nd4jBackend backend) {
|
||||||
//If out = variance(in) then:
|
//If out = variance(in) then:
|
||||||
//dL/dIn = dL/dOut * dOut/dIn
|
//dL/dIn = dL/dOut * dOut/dIn
|
||||||
|
@ -667,9 +650,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testVarianceAlongDimensionBP(Nd4jBackend backend) {
|
public void testVarianceAlongDimensionBP(Nd4jBackend backend) {
|
||||||
//If out = variance(in) then:
|
//If out = variance(in) then:
|
||||||
//dL/dIn = dL/dOut * dOut/dIn
|
//dL/dIn = dL/dOut * dOut/dIn
|
||||||
|
@ -711,9 +693,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCumSumBP(Nd4jBackend backend) {
|
public void testCumSumBP(Nd4jBackend backend) {
|
||||||
//Standard case, non-reverse, non-exclusive
|
//Standard case, non-reverse, non-exclusive
|
||||||
//dL/dIn_i = sum_j dL/dOut_j * dOut_j/dIn_i
|
//dL/dIn_i = sum_j dL/dOut_j * dOut_j/dIn_i
|
||||||
|
@ -783,9 +764,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNorm2Bp(Nd4jBackend backend) {
|
public void testNorm2Bp(Nd4jBackend backend) {
|
||||||
//dL/dIn = dL/dOut * dOut/dIn
|
//dL/dIn = dL/dOut * dOut/dIn
|
||||||
// = dL/dOut * x/|x|_2
|
// = dL/dOut * x/|x|_2
|
||||||
|
@ -812,9 +792,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNorm2AlongDimensionBP(Nd4jBackend backend) {
|
public void testNorm2AlongDimensionBP(Nd4jBackend backend) {
|
||||||
//dL/dIn = dL/dOut * dOut/dIn
|
//dL/dIn = dL/dOut * dOut/dIn
|
||||||
// = dL/dOut * x/|x|_2
|
// = dL/dOut * x/|x|_2
|
||||||
|
@ -847,9 +826,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNorm1Bp(Nd4jBackend backend) {
|
public void testNorm1Bp(Nd4jBackend backend) {
|
||||||
//dL/dIn = dL/dOut * dOut/dIn
|
//dL/dIn = dL/dOut * dOut/dIn
|
||||||
// = dL/dOut * sgn(in)
|
// = dL/dOut * sgn(in)
|
||||||
|
@ -876,9 +854,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNorm1AlongDimensionBP(Nd4jBackend backend) {
|
public void testNorm1AlongDimensionBP(Nd4jBackend backend) {
|
||||||
//dL/dIn = dL/dOut * dOut/dIn
|
//dL/dIn = dL/dOut * dOut/dIn
|
||||||
// = dL/dOut * sgn(in)
|
// = dL/dOut * sgn(in)
|
||||||
|
@ -910,9 +887,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNormMaxBp(Nd4jBackend backend) {
|
public void testNormMaxBp(Nd4jBackend backend) {
|
||||||
//out = max_i (|in_i|)
|
//out = max_i (|in_i|)
|
||||||
//dL/dIn = dL/dOut * dOut/dIn
|
//dL/dIn = dL/dOut * dOut/dIn
|
||||||
|
@ -942,9 +918,8 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNormMaxAlongDimensionBP(Nd4jBackend backend) {
|
public void testNormMaxAlongDimensionBP(Nd4jBackend backend) {
|
||||||
//out = max_i (|in_i|)
|
//out = max_i (|in_i|)
|
||||||
//dL/dIn = dL/dOut * dOut/dIn
|
//dL/dIn = dL/dOut * dOut/dIn
|
||||||
|
|
|
@ -80,9 +80,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class ReductionOpValidation extends BaseOpValidation {
|
public class ReductionOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStdev(Nd4jBackend backend) {
|
public void testStdev(Nd4jBackend backend) {
|
||||||
List<String> errors = new ArrayList<>();
|
List<String> errors = new ArrayList<>();
|
||||||
|
|
||||||
|
@ -108,9 +107,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, errors.size(),errors.toString());
|
assertEquals(0, errors.size(),errors.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testZeroCount(Nd4jBackend backend) {
|
public void testZeroCount(Nd4jBackend backend) {
|
||||||
List<String> allFailed = new ArrayList<>();
|
List<String> allFailed = new ArrayList<>();
|
||||||
for (int i = 0; i < 21; i++) {
|
for (int i = 0; i < 21; i++) {
|
||||||
|
@ -144,9 +142,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testZeroFraction(Nd4jBackend backend) {
|
public void testZeroFraction(Nd4jBackend backend) {
|
||||||
List<String> allFailed = new ArrayList<>();
|
List<String> allFailed = new ArrayList<>();
|
||||||
for (int i = 0; i < 2; i++) {
|
for (int i = 0; i < 2; i++) {
|
||||||
|
@ -176,9 +173,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, allFailed.size(),allFailed.toString());
|
assertEquals(0, allFailed.size(),allFailed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReductionGradientsSimple(Nd4jBackend backend) {
|
public void testReductionGradientsSimple(Nd4jBackend backend) {
|
||||||
//OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES
|
//OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES
|
||||||
//Test reductions: final and only function
|
//Test reductions: final and only function
|
||||||
|
@ -347,9 +343,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),failed.toString());
|
assertEquals(0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReductionGradients1(Nd4jBackend backend) {
|
public void testReductionGradients1(Nd4jBackend backend) {
|
||||||
//Test reductions: final, but *not* the only function
|
//Test reductions: final, but *not* the only function
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -477,9 +472,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
return Long.MAX_VALUE;
|
return Long.MAX_VALUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReductionGradients2(Nd4jBackend backend) {
|
public void testReductionGradients2(Nd4jBackend backend) {
|
||||||
//Test reductions: NON-final function
|
//Test reductions: NON-final function
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -657,9 +651,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReduce3(Nd4jBackend backend) {
|
public void testReduce3(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -764,9 +757,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),"Failed: " + failed);
|
assertEquals(0, failed.size(),"Failed: " + failed);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMoments(Nd4jBackend backend) {
|
public void testMoments(Nd4jBackend backend) {
|
||||||
for (int[] axes : new int[][]{{0}, {1}, {0, 1}}) {
|
for (int[] axes : new int[][]{{0}, {1}, {0, 1}}) {
|
||||||
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||||
|
@ -798,9 +790,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMomentsOp(Nd4jBackend backend) {
|
public void testMomentsOp(Nd4jBackend backend) {
|
||||||
int[] axes = new int[]{0};
|
int[] axes = new int[]{0};
|
||||||
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
INDArray input = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||||
|
@ -817,9 +808,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNormalizeMomentsOp(Nd4jBackend backend) {
|
public void testNormalizeMomentsOp(Nd4jBackend backend) {
|
||||||
INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10);
|
INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10);
|
||||||
INDArray ssSum = data.sum(0);
|
INDArray ssSum = data.sum(0);
|
||||||
|
@ -839,9 +829,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAllAny(Nd4jBackend backend) {
|
public void testAllAny(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray allZeros = Nd4j.zeros(DataType.FLOAT, 3, 4);
|
INDArray allZeros = Nd4j.zeros(DataType.FLOAT, 3, 4);
|
||||||
|
@ -869,9 +858,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIndexAccum(Nd4jBackend backend) {
|
public void testIndexAccum(Nd4jBackend backend) {
|
||||||
List<String> failed = new ArrayList<>();
|
List<String> failed = new ArrayList<>();
|
||||||
List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/);
|
List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/);
|
||||||
|
@ -960,9 +948,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReduce3_2(Nd4jBackend backend) {
|
public void testReduce3_2(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -1060,9 +1047,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReductionsBackwards(Nd4jBackend backend) {
|
public void testReductionsBackwards(Nd4jBackend backend) {
|
||||||
// for (int i = 0; i < 7; i++) {
|
// for (int i = 0; i < 7; i++) {
|
||||||
int i=5;
|
int i=5;
|
||||||
|
@ -1131,9 +1117,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDotProductAttention(){
|
public void testDotProductAttention(){
|
||||||
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
|
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
|
||||||
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
|
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
|
||||||
|
@ -1158,9 +1143,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDotProductAttentionWithMask(){
|
public void testDotProductAttentionWithMask(){
|
||||||
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
|
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
|
||||||
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
|
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
|
||||||
|
@ -1190,9 +1174,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDotProductAttentionMultiHeadInputWithMask(){
|
public void testDotProductAttentionMultiHeadInputWithMask(){
|
||||||
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
|
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
|
||||||
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
|
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
|
||||||
|
@ -1223,9 +1206,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDotProductAttentionMultiHeadInput(){
|
public void testDotProductAttentionMultiHeadInput(){
|
||||||
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
|
final INDArray keys = Nd4j.rand(new int[]{2, 5, 4, 3});
|
||||||
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
|
final INDArray values = Nd4j.rand(new int[]{2, 5, 4, 3});
|
||||||
|
@ -1252,9 +1234,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMultiHeadedDotProductAttention(){
|
public void testMultiHeadedDotProductAttention(){
|
||||||
final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
|
final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
|
||||||
final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
|
final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
|
||||||
|
@ -1305,9 +1286,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDotProductAttentionWeirdInputs(){
|
public void testDotProductAttentionWeirdInputs(){
|
||||||
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
|
final INDArray keys = Nd4j.rand(new int[]{10, 4, 3});
|
||||||
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
|
final INDArray values = Nd4j.rand(new int[]{10, 4, 3});
|
||||||
|
@ -1344,9 +1324,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMultiHeadedDotProductAttentionWeirdInputs(){
|
public void testMultiHeadedDotProductAttentionWeirdInputs(){
|
||||||
final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
|
final INDArray k = Nd4j.rand(new int[]{10, 4, 5});
|
||||||
final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
|
final INDArray v = Nd4j.rand(new int[]{10, 4, 5});
|
||||||
|
@ -1403,9 +1382,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSufficientStatisticsOp(Nd4jBackend backend) {
|
public void testSufficientStatisticsOp(Nd4jBackend backend) {
|
||||||
INDArray data = Nd4j.createFromArray(new double[]{
|
INDArray data = Nd4j.createFromArray(new double[]{
|
||||||
5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1.,
|
5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1.,
|
||||||
|
@ -1431,9 +1409,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStandardDeviation(Nd4jBackend backend) {
|
public void testStandardDeviation(Nd4jBackend backend) {
|
||||||
|
|
||||||
for (boolean keepDims : new boolean[]{false, true}) {
|
for (boolean keepDims : new boolean[]{false, true}) {
|
||||||
|
@ -1460,9 +1437,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSquaredNorm(Nd4jBackend backend) {
|
public void testSquaredNorm(Nd4jBackend backend) {
|
||||||
|
|
||||||
for (boolean keepDims : new boolean[]{false, true}) {
|
for (boolean keepDims : new boolean[]{false, true}) {
|
||||||
|
@ -1485,9 +1461,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testShannonEntropy(Nd4jBackend backend) {
|
public void testShannonEntropy(Nd4jBackend backend) {
|
||||||
OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695
|
OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695
|
||||||
|
|
||||||
|
@ -1507,9 +1482,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEntropy(Nd4jBackend backend) {
|
public void testEntropy(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
@ -1528,9 +1502,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAMean(Nd4jBackend backend) {
|
public void testAMean(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
@ -1551,9 +1524,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMean(Nd4jBackend backend) {
|
public void testMean(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
@ -1574,9 +1546,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNorm1(Nd4jBackend backend) {
|
public void testNorm1(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
@ -1597,9 +1568,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNorm2(Nd4jBackend backend) {
|
public void testNorm2(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
@ -1620,9 +1590,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNormMax(Nd4jBackend backend) {
|
public void testNormMax(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
@ -1643,9 +1612,8 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSoftmaxCrossEntropyWithLogitsLoss(Nd4jBackend backend) {
|
public void testSoftmaxCrossEntropyWithLogitsLoss(Nd4jBackend backend) {
|
||||||
OpValidationSuite.ignoreFailing();
|
OpValidationSuite.ignoreFailing();
|
||||||
|
|
||||||
|
|
|
@ -46,9 +46,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class RnnOpValidation extends BaseOpValidation {
|
public class RnnOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRnnBlockCell(Nd4jBackend backend) {
|
public void testRnnBlockCell(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
int mb = 2;
|
int mb = 2;
|
||||||
|
@ -147,9 +146,8 @@ public class RnnOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRnnBlockCellManualTFCompare(Nd4jBackend backend) {
|
public void testRnnBlockCellManualTFCompare(Nd4jBackend backend) {
|
||||||
//Test case: "rnn/lstmblockcell/static_batch1_n3-2_tsLength1_noPH_noClip_fBias1_noIS"
|
//Test case: "rnn/lstmblockcell/static_batch1_n3-2_tsLength1_noPH_noClip_fBias1_noIS"
|
||||||
|
|
||||||
|
@ -211,9 +209,8 @@ public class RnnOpValidation extends BaseOpValidation {
|
||||||
assertEquals(out6, m.get(toExec.get(6))); //Output
|
assertEquals(out6, m.get(toExec.get(6))); //Output
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGRUCell(){
|
public void testGRUCell(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
int mb = 2;
|
int mb = 2;
|
||||||
|
|
|
@ -81,9 +81,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
doRepeat
|
doRepeat
|
||||||
*/
|
*/
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConcat(Nd4jBackend backend) {
|
public void testConcat(Nd4jBackend backend) {
|
||||||
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
|
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
|
||||||
int[] concatDim = new int[]{0, 0, 0};
|
int[] concatDim = new int[]{0, 0, 0};
|
||||||
|
@ -123,9 +122,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals( 0, failed.size(),failed.toString());
|
assertEquals( 0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReshapeGradient(Nd4jBackend backend) {
|
public void testReshapeGradient(Nd4jBackend backend) {
|
||||||
//https://github.com/deeplearning4j/deeplearning4j/issues/6873
|
//https://github.com/deeplearning4j/deeplearning4j/issues/6873
|
||||||
|
|
||||||
|
@ -161,9 +159,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),failed.toString());
|
assertEquals(0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPermuteGradient(Nd4jBackend backend) {
|
public void testPermuteGradient(Nd4jBackend backend) {
|
||||||
int[] origShape = new int[]{3, 4, 5};
|
int[] origShape = new int[]{3, 4, 5};
|
||||||
|
|
||||||
|
@ -201,9 +198,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),failed.toString());
|
assertEquals(0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRank(){
|
public void testRank(){
|
||||||
|
|
||||||
List<long[]> inShape = Arrays.asList(null, new long[]{1}, new long[]{6}, new long[]{3,4}, new long[]{3,4,5});
|
List<long[]> inShape = Arrays.asList(null, new long[]{1}, new long[]{6}, new long[]{3,4}, new long[]{3,4,5});
|
||||||
|
@ -230,9 +226,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testExpandDimsGradient(Nd4jBackend backend) {
|
public void testExpandDimsGradient(Nd4jBackend backend) {
|
||||||
val origShape = new long[]{3, 4};
|
val origShape = new long[]{3, 4};
|
||||||
|
|
||||||
|
@ -288,9 +283,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),failed.toString());
|
assertEquals(0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSqueezeGradient(Nd4jBackend backend) {
|
public void testSqueezeGradient(Nd4jBackend backend) {
|
||||||
val origShape = new long[]{3, 4, 5};
|
val origShape = new long[]{3, 4, 5};
|
||||||
|
|
||||||
|
@ -354,9 +348,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSliceGradient(Nd4jBackend backend) {
|
public void testSliceGradient(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -446,9 +439,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStridedSliceGradient(Nd4jBackend backend) {
|
public void testStridedSliceGradient(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -511,9 +503,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMerge(Nd4jBackend backend) {
|
public void testMerge(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -680,9 +671,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUnStack(Nd4jBackend backend) {
|
public void testUnStack(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -770,9 +760,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals( 0, failed.size(),failed.toString());
|
assertEquals( 0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTile(Nd4jBackend backend) {
|
public void testTile(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -844,9 +833,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTileBp(){
|
public void testTileBp(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -879,9 +867,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTileBp2(){
|
public void testTileBp2(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -915,9 +902,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReshape(Nd4jBackend backend) {
|
public void testReshape(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray arr = Transforms.sigmoid(Nd4j.linspace(-5, 6, 12)).reshape(3, 4);
|
INDArray arr = Transforms.sigmoid(Nd4j.linspace(-5, 6, 12)).reshape(3, 4);
|
||||||
|
@ -933,9 +919,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReshape2(Nd4jBackend backend) {
|
public void testReshape2(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
int[] origShape = new int[]{3, 4, 5};
|
int[] origShape = new int[]{3, 4, 5};
|
||||||
|
@ -958,9 +943,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTranspose(Nd4jBackend backend) {
|
public void testTranspose(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
|
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
|
||||||
|
@ -972,9 +956,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTransposeOp(){
|
public void testTransposeOp(){
|
||||||
|
|
||||||
INDArray arr = Nd4j.linspace(1,15, 15).reshape(5,3);
|
INDArray arr = Nd4j.linspace(1,15, 15).reshape(5,3);
|
||||||
|
@ -987,9 +970,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testShape(Nd4jBackend backend) {
|
public void testShape(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
val shape = new long[]{2, 3};
|
val shape = new long[]{2, 3};
|
||||||
|
@ -1004,9 +986,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSize(Nd4jBackend backend) {
|
public void testSize(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
val shape = new long[]{2, 3};
|
val shape = new long[]{2, 3};
|
||||||
|
@ -1020,9 +1001,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDiagShapeFn(Nd4jBackend backend) {
|
public void testDiagShapeFn(Nd4jBackend backend) {
|
||||||
INDArray i = Nd4j.linspace(1, 16, 16).reshape(4,4);
|
INDArray i = Nd4j.linspace(1, 16, 16).reshape(4,4);
|
||||||
|
|
||||||
|
@ -1036,9 +1016,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPermute(){
|
public void testPermute(){
|
||||||
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
|
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
|
||||||
INDArray exp = in.permute(0,1,2); //No op
|
INDArray exp = in.permute(0,1,2); //No op
|
||||||
|
@ -1052,9 +1031,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertNull(OpValidation.validate(op));
|
assertNull(OpValidation.validate(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPermute2(){
|
public void testPermute2(){
|
||||||
for (int[] perm : new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}) {
|
for (int[] perm : new int[][]{{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}}) {
|
||||||
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
|
INDArray in = Nd4j.linspace(1, 60, 60).reshape(3,4,5);
|
||||||
|
@ -1074,9 +1052,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConstant(){
|
public void testConstant(){
|
||||||
//OpValidationSuite.ignoreFailing();
|
//OpValidationSuite.ignoreFailing();
|
||||||
|
|
||||||
|
@ -1103,9 +1080,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUnstackEdgeCase2(){
|
public void testUnstackEdgeCase2(){
|
||||||
for( int i=0; i<3; i++ ) {
|
for( int i=0; i<3; i++ ) {
|
||||||
|
|
||||||
|
@ -1119,9 +1095,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void invertPermutation(Nd4jBackend backend) {
|
public void invertPermutation(Nd4jBackend backend) {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
||||||
|
@ -1138,9 +1113,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGatherNd(){
|
public void testGatherNd(){
|
||||||
|
|
||||||
List<INDArray> indices = new ArrayList<>();
|
List<INDArray> indices = new ArrayList<>();
|
||||||
|
@ -1178,9 +1152,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReverseSequence(Nd4jBackend backend) {
|
public void testReverseSequence(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
float[] input_data = new float[]{
|
float[] input_data = new float[]{
|
||||||
|
@ -1226,9 +1199,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMatrixDeterminant(){
|
public void testMatrixDeterminant(){
|
||||||
OpValidationSuite.ignoreFailing(); //Gradient check failing
|
OpValidationSuite.ignoreFailing(); //Gradient check failing
|
||||||
|
|
||||||
|
@ -1249,9 +1221,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDeterminant22(){
|
public void testDeterminant22(){
|
||||||
OpValidationSuite.ignoreFailing(); //Gradient check failing
|
OpValidationSuite.ignoreFailing(); //Gradient check failing
|
||||||
|
|
||||||
|
@ -1275,9 +1246,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMatrixDeterminant3(){
|
public void testMatrixDeterminant3(){
|
||||||
OpValidationSuite.ignoreFailing(); //Gradient checks failing
|
OpValidationSuite.ignoreFailing(); //Gradient checks failing
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -1308,9 +1278,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMatrixDeterminant4(){
|
public void testMatrixDeterminant4(){
|
||||||
OpValidationSuite.ignoreFailing(); //Gradient checks failing
|
OpValidationSuite.ignoreFailing(); //Gradient checks failing
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -1330,9 +1299,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSegmentOps(){
|
public void testSegmentOps(){
|
||||||
OpValidationSuite.ignoreFailing();
|
OpValidationSuite.ignoreFailing();
|
||||||
//https://github.com/deeplearning4j/deeplearning4j/issues/6952
|
//https://github.com/deeplearning4j/deeplearning4j/issues/6952
|
||||||
|
@ -1424,9 +1392,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),failed.toString());
|
assertEquals(0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSegmentMean(){
|
public void testSegmentMean(){
|
||||||
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3);
|
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3);
|
||||||
INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2);
|
INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2);
|
||||||
|
@ -1446,9 +1413,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(exp, out);
|
assertEquals(exp, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSequenceMask(Nd4jBackend backend) {
|
public void testSequenceMask(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2});
|
INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2});
|
||||||
|
@ -1482,9 +1448,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(expected, result2.eval());
|
assertEquals(expected, result2.eval());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMeshGrid(){
|
public void testMeshGrid(){
|
||||||
List<String> failed = new ArrayList<>();
|
List<String> failed = new ArrayList<>();
|
||||||
|
|
||||||
|
@ -1540,9 +1505,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGather(){
|
public void testGather(){
|
||||||
List<INDArray> inArrs = new ArrayList<>();
|
List<INDArray> inArrs = new ArrayList<>();
|
||||||
List<Integer> axis = new ArrayList<>();
|
List<Integer> axis = new ArrayList<>();
|
||||||
|
@ -1611,9 +1575,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGatherSimple(Nd4jBackend backend) {
|
public void testGatherSimple(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray arr = Nd4j.create(new float[]{1, 2, 3, 4}, new long[]{2, 2});
|
INDArray arr = Nd4j.create(new float[]{1, 2, 3, 4}, new long[]{2, 2});
|
||||||
|
@ -1623,9 +1586,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(expected, result.eval());
|
assertEquals(expected, result.eval());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGatherNdSingle(Nd4jBackend backend) {
|
public void testGatherNdSingle(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(DataType.DOUBLE, 1, 24, 24)).reshape(2, 3, 4);
|
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(DataType.DOUBLE, 1, 24, 24)).reshape(2, 3, 4);
|
||||||
|
@ -1644,9 +1606,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(expected, result.eval());
|
assertEquals(expected, result.eval());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStack2(Nd4jBackend backend) {
|
public void testStack2(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
|
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
|
||||||
|
@ -1657,9 +1618,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(new long[]{3, 2, 2}, result.eval().shape());
|
assertArrayEquals(new long[]{3, 2, 2}, result.eval().shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testParallelStack(Nd4jBackend backend) {
|
public void testParallelStack(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
|
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 6, 6)).reshape(3, 2);
|
||||||
|
@ -1671,9 +1631,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(Nd4j.concat(0, arr1, arr2).reshape(2, 3, 2), result.eval());
|
assertEquals(Nd4j.concat(0, arr1, arr2).reshape(2, 3, 2), result.eval());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUnStack2(Nd4jBackend backend) {
|
public void testUnStack2(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray arr1 = Nd4j.zeros(3, 2);
|
INDArray arr1 = Nd4j.zeros(3, 2);
|
||||||
|
@ -1686,9 +1645,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(arr2, result[1].eval());
|
assertEquals(arr2, result[1].eval());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPermuteSimple(Nd4jBackend backend) {
|
public void testPermuteSimple(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3));
|
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3));
|
||||||
|
@ -1699,9 +1657,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConcat2(Nd4jBackend backend) {
|
public void testConcat2(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
|
INDArray arr1 = Transforms.sigmoid(Nd4j.linspace(1, 4, 4)).reshape(1,4);
|
||||||
|
@ -1712,9 +1669,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(new long[]{2, 4}, result.eval().shape());
|
assertArrayEquals(new long[]{2, 4}, result.eval().shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTile2(Nd4jBackend backend) {
|
public void testTile2(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1,4));
|
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1,4));
|
||||||
|
@ -1727,9 +1683,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSlice2d(Nd4jBackend backend) {
|
public void testSlice2d(Nd4jBackend backend) {
|
||||||
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
|
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
|
||||||
|
|
||||||
|
@ -1745,9 +1700,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSlice3d(Nd4jBackend backend) {
|
public void testSlice3d(Nd4jBackend backend) {
|
||||||
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
||||||
|
|
||||||
|
@ -1762,9 +1716,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(inArr.get(interval(1, 3), interval(2, 4), interval(3, 4)), m.get(subPart.name()));
|
assertEquals(inArr.get(interval(1, 3), interval(2, 4), interval(3, 4)), m.get(subPart.name()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStridedSlice2dBasic(Nd4jBackend backend) {
|
public void testStridedSlice2dBasic(Nd4jBackend backend) {
|
||||||
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
|
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
|
||||||
|
|
||||||
|
@ -1782,9 +1735,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStridedSliceBeginEndMask(Nd4jBackend backend) {
|
public void testStridedSliceBeginEndMask(Nd4jBackend backend) {
|
||||||
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
|
INDArray inArr = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4);
|
||||||
|
|
||||||
|
@ -1799,9 +1751,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(inArr.get(NDArrayIndex.interval(1, 3), NDArrayIndex.all()), slice2.getArr());
|
assertEquals(inArr.get(NDArrayIndex.interval(1, 3), NDArrayIndex.all()), slice2.getArr());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStridedSliceEllipsisMask(Nd4jBackend backend) {
|
public void testStridedSliceEllipsisMask(Nd4jBackend backend) {
|
||||||
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -1818,9 +1769,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(inArr.get(interval(1, 3), all(), all()), slice2.getArr());
|
assertEquals(inArr.get(interval(1, 3), all(), all()), slice2.getArr());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStridedSliceNewAxisMask(Nd4jBackend backend) {
|
public void testStridedSliceNewAxisMask(Nd4jBackend backend) {
|
||||||
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -1833,9 +1783,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(inArr, out.get(point(0), all(), all(), all()));
|
assertEquals(inArr, out.get(point(0), all(), all(), all()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStridedSliceNewAxisMask2(Nd4jBackend backend) {
|
public void testStridedSliceNewAxisMask2(Nd4jBackend backend) {
|
||||||
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -1846,9 +1795,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(new long[]{2, 2, 1, 3}, slice.getArr().shape());
|
assertArrayEquals(new long[]{2, 2, 1, 3}, slice.getArr().shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStridedSliceShrinkAxisMask(Nd4jBackend backend) {
|
public void testStridedSliceShrinkAxisMask(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
INDArray inArr = Nd4j.linspace(1, 60, 60).reshape('c', 3, 4, 5);
|
||||||
|
@ -1865,9 +1813,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(inArr.get(point(1), point(2), interval(1, 5)).reshape(4), slice3.getArr());
|
assertEquals(inArr.get(point(1), point(2), interval(1, 5)).reshape(4), slice3.getArr());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSizeAt_1(Nd4jBackend backend) {
|
public void testSizeAt_1(Nd4jBackend backend) {
|
||||||
val array = Nd4j.create(10, 20, 30);
|
val array = Nd4j.create(10, 20, 30);
|
||||||
val exp = Nd4j.scalar(DataType.LONG, 20);
|
val exp = Nd4j.scalar(DataType.LONG, 20);
|
||||||
|
@ -1881,9 +1828,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(exp, output);
|
assertEquals(exp, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEye(){
|
public void testEye(){
|
||||||
int[] rows = new int[]{3,3,3,3};
|
int[] rows = new int[]{3,3,3,3};
|
||||||
int[] cols = new int[]{3,2,2,2};
|
int[] cols = new int[]{3,2,2,2};
|
||||||
|
@ -1921,9 +1867,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSplit1(){
|
public void testSplit1(){
|
||||||
INDArray in = Nd4j.linspace(1,10,10).reshape(10);
|
INDArray in = Nd4j.linspace(1,10,10).reshape(10);
|
||||||
INDArray axis = Nd4j.scalar(-1);
|
INDArray axis = Nd4j.scalar(-1);
|
||||||
|
@ -1941,9 +1886,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
.build()).expectedOutput(0, exp1).expectedOutput(1,exp2)));
|
.build()).expectedOutput(0, exp1).expectedOutput(1,exp2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSplit2(){
|
public void testSplit2(){
|
||||||
INDArray in = Nd4j.linspace(1,24,24).reshape(3,8);
|
INDArray in = Nd4j.linspace(1,24,24).reshape(3,8);
|
||||||
INDArray axis = Nd4j.scalar(-1);
|
INDArray axis = Nd4j.scalar(-1);
|
||||||
|
@ -1961,9 +1905,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
.build()).expectedOutput(0, exp1).expectedOutput(1,exp2)));
|
.build()).expectedOutput(0, exp1).expectedOutput(1,exp2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDistancesExec(){
|
public void testDistancesExec(){
|
||||||
//https://github.com/deeplearning4j/deeplearning4j/issues/7001
|
//https://github.com/deeplearning4j/deeplearning4j/issues/7001
|
||||||
for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) {
|
for(String s : new String[]{"euclidean", "manhattan", "cosinesim", "cosinedist", "jaccard"}) {
|
||||||
|
@ -2018,9 +1961,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReductionShape(){
|
public void testReductionShape(){
|
||||||
|
|
||||||
INDArray shape = Nd4j.createFromArray(4,2);
|
INDArray shape = Nd4j.createFromArray(4,2);
|
||||||
|
@ -2038,9 +1980,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(exp, s); //Fails - actual shape [1]
|
assertArrayEquals(exp, s); //Fails - actual shape [1]
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void gatherTest(){
|
public void gatherTest(){
|
||||||
INDArray in = Nd4j.createFromArray(new double[][]{
|
INDArray in = Nd4j.createFromArray(new double[][]{
|
||||||
{1,2,3,4,5},
|
{1,2,3,4,5},
|
||||||
|
@ -2059,9 +2000,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(expShape, shape); //Fails: actual shape: [5]
|
assertArrayEquals(expShape, shape); //Fails: actual shape: [5]
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSliceShape(){
|
public void testSliceShape(){
|
||||||
|
|
||||||
INDArray arr = Nd4j.arange(0, 25).reshape(1,5,5).castTo(DataType.INT);
|
INDArray arr = Nd4j.arange(0, 25).reshape(1,5,5).castTo(DataType.INT);
|
||||||
|
@ -2082,9 +2022,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(shapeExp, shape);
|
assertArrayEquals(shapeExp, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testWhereAllFalse(){
|
public void testWhereAllFalse(){
|
||||||
INDArray in = Nd4j.create(DataType.BOOL, 1917);
|
INDArray in = Nd4j.create(DataType.BOOL, 1917);
|
||||||
DynamicCustomOp op = DynamicCustomOp.builder("Where")
|
DynamicCustomOp op = DynamicCustomOp.builder("Where")
|
||||||
|
@ -2098,9 +2037,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertTrue(isEmpty); //Not empty, but should be
|
assertTrue(isEmpty); //Not empty, but should be
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGatherScalar(){
|
public void testGatherScalar(){
|
||||||
INDArray in = Nd4j.linspace(100, 200, 100, DataType.FLOAT).reshape(100);
|
INDArray in = Nd4j.linspace(100, 200, 100, DataType.FLOAT).reshape(100);
|
||||||
INDArray indices = Nd4j.scalar(0);
|
INDArray indices = Nd4j.scalar(0);
|
||||||
|
@ -2124,9 +2062,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(exp, arr);
|
assertEquals(exp, arr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCastEmpty(){
|
public void testCastEmpty(){
|
||||||
INDArray emptyLong = Nd4j.empty(DataType.LONG);
|
INDArray emptyLong = Nd4j.empty(DataType.LONG);
|
||||||
int dtype = 9; //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h
|
int dtype = 9; //INT = 9 - https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/array/DataType.h
|
||||||
|
@ -2142,9 +2079,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertTrue(isEmpty);
|
assertTrue(isEmpty);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGatherEmpty(){
|
public void testGatherEmpty(){
|
||||||
/*
|
/*
|
||||||
tf.reset_default_graph()
|
tf.reset_default_graph()
|
||||||
|
@ -2176,9 +2112,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertTrue(isEmpty);
|
assertTrue(isEmpty);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSplitEmpty(){
|
public void testSplitEmpty(){
|
||||||
/*
|
/*
|
||||||
tf.reset_default_graph()
|
tf.reset_default_graph()
|
||||||
|
@ -2215,9 +2150,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
Nd4j.exec(op);
|
Nd4j.exec(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConcatEmpty(){
|
public void testConcatEmpty(){
|
||||||
/*
|
/*
|
||||||
TF behaviour with concatenatioun of empty arrays:
|
TF behaviour with concatenatioun of empty arrays:
|
||||||
|
@ -2266,9 +2200,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
Nd4j.exec(op);
|
Nd4j.exec(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConcatEmpty2(){
|
public void testConcatEmpty2(){
|
||||||
INDArray empty10a = Nd4j.create(DataType.INT, 1, 0);
|
INDArray empty10a = Nd4j.create(DataType.INT, 1, 0);
|
||||||
INDArray empty10b = Nd4j.create(DataType.INT, 1, 0);
|
INDArray empty10b = Nd4j.create(DataType.INT, 1, 0);
|
||||||
|
@ -2300,9 +2233,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
Nd4j.exec(op);
|
Nd4j.exec(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEmptyGather(){
|
public void testEmptyGather(){
|
||||||
/*
|
/*
|
||||||
tf.reset_default_graph()
|
tf.reset_default_graph()
|
||||||
|
@ -2334,9 +2266,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
op.addOutputArgument(out);
|
op.addOutputArgument(out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBroadcastDynamicShape1(){
|
public void testBroadcastDynamicShape1(){
|
||||||
|
|
||||||
//Test case: [2,1] and [4]: expect [2,4]
|
//Test case: [2,1] and [4]: expect [2,4]
|
||||||
|
@ -2357,9 +2288,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(Nd4j.createFromArray(new int[]{2,4}), out);
|
assertEquals(Nd4j.createFromArray(new int[]{2,4}), out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBroadcastDynamicShape2(){
|
public void testBroadcastDynamicShape2(){
|
||||||
|
|
||||||
//Test case: [2,1,4] and [2,2,4]: expect [2,2,4]
|
//Test case: [2,1,4] and [2,2,4]: expect [2,2,4]
|
||||||
|
@ -2381,9 +2311,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(Nd4j.createFromArray(new int[]{2,4,3}), out);
|
assertEquals(Nd4j.createFromArray(new int[]{2,4,3}), out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStridedSliceShrinkAxis(){
|
public void testStridedSliceShrinkAxis(){
|
||||||
INDArray in = Nd4j.create(DataType.DOUBLE, 3,2,2);
|
INDArray in = Nd4j.create(DataType.DOUBLE, 3,2,2);
|
||||||
INDArray begin = Nd4j.createFromArray(2);
|
INDArray begin = Nd4j.createFromArray(2);
|
||||||
|
@ -2408,9 +2337,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(exp, shape);
|
assertArrayEquals(exp, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStridedSliceEmpty(){
|
public void testStridedSliceEmpty(){
|
||||||
|
|
||||||
INDArray in = Nd4j.createFromArray(10); //Integer, Length 1, rank 1, value 10 - Not used due to begin mask!
|
INDArray in = Nd4j.createFromArray(10); //Integer, Length 1, rank 1, value 10 - Not used due to begin mask!
|
||||||
|
@ -2432,9 +2360,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertTrue(isEmpty);
|
assertTrue(isEmpty);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStridedSliceEdgeCase(){
|
public void testStridedSliceEdgeCase(){
|
||||||
INDArray in = Nd4j.scalar(10).reshape(1); //Int [1]
|
INDArray in = Nd4j.scalar(10).reshape(1); //Int [1]
|
||||||
INDArray begin = Nd4j.ones(DataType.INT, 1);
|
INDArray begin = Nd4j.ones(DataType.INT, 1);
|
||||||
|
@ -2459,9 +2386,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
Nd4j.exec(op); //Execution is OK
|
Nd4j.exec(op); //Execution is OK
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEmptySlice1(){
|
public void testEmptySlice1(){
|
||||||
INDArray in = Nd4j.createFromArray(38);
|
INDArray in = Nd4j.createFromArray(38);
|
||||||
INDArray begin = Nd4j.createFromArray(1);
|
INDArray begin = Nd4j.createFromArray(1);
|
||||||
|
@ -2480,9 +2406,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
Nd4j.exec(op);
|
Nd4j.exec(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEmptySlice2(){
|
public void testEmptySlice2(){
|
||||||
INDArray in = Nd4j.createFromArray(38);
|
INDArray in = Nd4j.createFromArray(38);
|
||||||
INDArray begin = Nd4j.createFromArray(0);
|
INDArray begin = Nd4j.createFromArray(0);
|
||||||
|
@ -2501,9 +2426,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
Nd4j.exec(op);
|
Nd4j.exec(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testFill(){
|
public void testFill(){
|
||||||
|
|
||||||
INDArray shape = Nd4j.createFromArray(0,4);
|
INDArray shape = Nd4j.createFromArray(0,4);
|
||||||
|
@ -2522,9 +2446,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
Nd4j.exec(op);
|
Nd4j.exec(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testFill2(){
|
public void testFill2(){
|
||||||
|
|
||||||
INDArray shape = Nd4j.createFromArray(0,4);
|
INDArray shape = Nd4j.createFromArray(0,4);
|
||||||
|
@ -2541,9 +2464,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
Nd4j.exec(op);
|
Nd4j.exec(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPermuteShapeDynamicAxis(){
|
public void testPermuteShapeDynamicAxis(){
|
||||||
|
|
||||||
DynamicCustomOp op = DynamicCustomOp.builder("permute")
|
DynamicCustomOp op = DynamicCustomOp.builder("permute")
|
||||||
|
@ -2572,9 +2494,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(new long[]{4, 5, 3}, l.get(0).getShape());
|
assertArrayEquals(new long[]{4, 5, 3}, l.get(0).getShape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGather2(){
|
public void testGather2(){
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable input = sd.var("in", Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3));
|
SDVariable input = sd.var("in", Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3));
|
||||||
|
@ -2593,9 +2514,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPermute3(){
|
public void testPermute3(){
|
||||||
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
|
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
|
||||||
INDArray permute = Nd4j.createFromArray(1,0);
|
INDArray permute = Nd4j.createFromArray(1,0);
|
||||||
|
@ -2613,9 +2533,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertEquals(exp, outArr);
|
assertEquals(exp, outArr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPermute4(){
|
public void testPermute4(){
|
||||||
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
|
INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2);
|
||||||
INDArray permute = Nd4j.createFromArray(1,0);
|
INDArray permute = Nd4j.createFromArray(1,0);
|
||||||
|
@ -2645,18 +2564,16 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInvertPermutation(){
|
public void testInvertPermutation(){
|
||||||
DynamicCustomOp op = DynamicCustomOp.builder("invert_permutation")
|
DynamicCustomOp op = DynamicCustomOp.builder("invert_permutation")
|
||||||
.addInputs(Nd4j.createFromArray(1, 0))
|
.addInputs(Nd4j.createFromArray(1, 0))
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBroadcastInt1(Nd4jBackend backend) {
|
public void testBroadcastInt1(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray out = Nd4j.create(DataType.INT, 1);
|
INDArray out = Nd4j.create(DataType.INT, 1);
|
||||||
|
@ -2669,9 +2586,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBroadcastInt2(){
|
public void testBroadcastInt2(){
|
||||||
INDArray out = Nd4j.create(DataType.INT, 2);
|
INDArray out = Nd4j.create(DataType.INT, 2);
|
||||||
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
|
DynamicCustomOp op = DynamicCustomOp.builder("broadcast_dynamic_shape")
|
||||||
|
@ -2710,9 +2626,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMergeMaxIndex(Nd4jBackend backend) {
|
public void testMergeMaxIndex(Nd4jBackend backend) {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -2729,9 +2644,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTriOp(Nd4jBackend backend) {
|
public void testTriOp(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -2743,9 +2657,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTriuOp(Nd4jBackend backend) {
|
public void testTriuOp(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
|
@ -118,9 +118,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
|
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScalarOps(Nd4jBackend backend) {
|
public void testScalarOps(Nd4jBackend backend) {
|
||||||
int d0 = 2;
|
int d0 = 2;
|
||||||
int d1 = 3;
|
int d1 = 3;
|
||||||
|
@ -217,9 +216,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),failed.toString());
|
assertEquals(0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScalarMulCF(Nd4jBackend backend) {
|
public void testScalarMulCF(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
|
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
|
||||||
|
@ -233,9 +231,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScalarMulCF2(Nd4jBackend backend) {
|
public void testScalarMulCF2(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
|
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
|
||||||
|
@ -246,9 +243,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertEquals(outC, outF);
|
assertEquals(outC, outF);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCross(Nd4jBackend backend) {
|
public void testCross(Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3});
|
INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3});
|
||||||
INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3});
|
INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3});
|
||||||
|
@ -276,9 +272,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err, err);
|
assertNull(err, err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSpaceToDepth(Nd4jBackend backend) {
|
public void testSpaceToDepth(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(1337);
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
|
||||||
|
@ -306,9 +301,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDepthToSpace(Nd4jBackend backend) {
|
public void testDepthToSpace(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(1337);
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
|
||||||
|
@ -335,9 +329,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err, err);
|
assertNull(err, err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBatchToSpace(Nd4jBackend backend) {
|
public void testBatchToSpace(Nd4jBackend backend) {
|
||||||
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
|
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
|
||||||
Nd4j.getRandom().setSeed(1337);
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
@ -374,9 +367,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err, err);
|
assertNull(err, err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSpaceToBatch(Nd4jBackend backend) {
|
public void testSpaceToBatch(Nd4jBackend backend) {
|
||||||
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
|
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/eclipse/deeplearning4j/issues/6863
|
||||||
|
|
||||||
|
@ -414,9 +406,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err, err);
|
assertNull(err, err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDynamicPartition(Nd4jBackend backend) {
|
public void testDynamicPartition(Nd4jBackend backend) {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
||||||
|
@ -456,9 +447,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDynamicPartition2(Nd4jBackend backend) {
|
public void testDynamicPartition2(Nd4jBackend backend) {
|
||||||
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
|
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
|
||||||
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
|
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
|
||||||
|
@ -476,9 +466,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertEquals(exp2, out[2]);
|
assertEquals(exp2, out[2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDynamicStitch(Nd4jBackend backend) {
|
public void testDynamicStitch(Nd4jBackend backend) {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
||||||
|
@ -515,9 +504,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDiag(Nd4jBackend backend) {
|
public void testDiag(Nd4jBackend backend) {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
||||||
|
@ -543,9 +531,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDiagPart(Nd4jBackend backend) {
|
public void testDiagPart(Nd4jBackend backend) {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
||||||
|
@ -564,9 +551,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEye(Nd4jBackend backend) {
|
public void testEye(Nd4jBackend backend) {
|
||||||
int[] rows = new int[]{3, 3, 3, 3};
|
int[] rows = new int[]{3, 3, 3, 3};
|
||||||
int[] cols = new int[]{3, 2, 2, 2};
|
int[] cols = new int[]{3, 2, 2, 2};
|
||||||
|
@ -600,9 +586,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEyeShape(Nd4jBackend backend) {
|
public void testEyeShape(Nd4jBackend backend) {
|
||||||
DynamicCustomOp dco = DynamicCustomOp.builder("eye")
|
DynamicCustomOp dco = DynamicCustomOp.builder("eye")
|
||||||
.addIntegerArguments(3, 3)
|
.addIntegerArguments(3, 3)
|
||||||
|
@ -614,9 +599,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(new long[]{3, 3}, list.get(0).getShape());
|
assertArrayEquals(new long[]{3, 3}, list.get(0).getShape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTransforms(Nd4jBackend backend) {
|
public void testTransforms(Nd4jBackend backend) {
|
||||||
//Test transforms (non-pairwise)
|
//Test transforms (non-pairwise)
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -1104,9 +1088,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPairwiseTransforms(Nd4jBackend backend) {
|
public void testPairwiseTransforms(Nd4jBackend backend) {
|
||||||
/*
|
/*
|
||||||
add, sub, mul, div, rsub, rdiv
|
add, sub, mul, div, rsub, rdiv
|
||||||
|
@ -1290,9 +1273,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIsX(Nd4jBackend backend) {
|
public void testIsX(Nd4jBackend backend) {
|
||||||
List<String> failed = new ArrayList<>();
|
List<String> failed = new ArrayList<>();
|
||||||
|
|
||||||
|
@ -1347,9 +1329,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertEquals(0, failed.size(),failed.toString());
|
assertEquals(0, failed.size(),failed.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReplaceWhereScalar(Nd4jBackend backend) {
|
public void testReplaceWhereScalar(Nd4jBackend backend) {
|
||||||
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
|
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
|
||||||
|
|
||||||
|
@ -1371,9 +1352,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReplaceWhereArray(Nd4jBackend backend) {
|
public void testReplaceWhereArray(Nd4jBackend backend) {
|
||||||
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
|
for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) {
|
||||||
|
|
||||||
|
@ -1396,9 +1376,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLogGrad(Nd4jBackend backend) {
|
public void testLogGrad(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
SDVariable input = sameDiff.var("x", Nd4j.linspace(1, 4, 4, DataType.DOUBLE));
|
SDVariable input = sameDiff.var("x", Nd4j.linspace(1, 4, 4, DataType.DOUBLE));
|
||||||
|
@ -1409,9 +1388,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSigmoidBackwards(Nd4jBackend backend) {
|
public void testSigmoidBackwards(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
|
@ -1429,9 +1407,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/* @Test
|
/* @ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testDepth(Nd4jBackend backend) {
|
public void testDepth(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
SDVariable x = sameDiff.one("one",new long[]{2,2});
|
SDVariable x = sameDiff.one("one",new long[]{2,2});
|
||||||
|
@ -1440,9 +1417,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertEquals(1,sigmoid.depth());
|
assertEquals(1,sigmoid.depth());
|
||||||
}*/
|
}*/
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRank0EdgeCase(Nd4jBackend backend) {
|
public void testRank0EdgeCase(Nd4jBackend backend) {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4})));
|
SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4})));
|
||||||
|
@ -1455,9 +1431,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertEquals(4, d1, 0);
|
assertEquals(4, d1, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAtan2BroadcastShape(Nd4jBackend backend) {
|
public void testAtan2BroadcastShape(Nd4jBackend backend) {
|
||||||
INDArray arr1 = Nd4j.create(new long[]{3, 1, 4});
|
INDArray arr1 = Nd4j.create(new long[]{3, 1, 4});
|
||||||
INDArray arr2 = Nd4j.create(new long[]{1, 2, 4});
|
INDArray arr2 = Nd4j.create(new long[]{1, 2, 4});
|
||||||
|
@ -1472,9 +1447,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertArrayEquals(new long[]{3, 2, 4}, outShapes.get(0).getShape(),Arrays.toString(outShapes.get(0).getShape()));
|
assertArrayEquals(new long[]{3, 2, 4}, outShapes.get(0).getShape(),Arrays.toString(outShapes.get(0).getShape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBooleanAnd(Nd4jBackend backend) {
|
public void testBooleanAnd(Nd4jBackend backend) {
|
||||||
Nd4j.setDataType(DataType.FLOAT);
|
Nd4j.setDataType(DataType.FLOAT);
|
||||||
INDArray arr1 = Nd4j.create(new long[]{3, 4});
|
INDArray arr1 = Nd4j.create(new long[]{3, 4});
|
||||||
|
@ -1488,9 +1462,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
Nd4j.getExecutioner().exec(op);
|
Nd4j.getExecutioner().exec(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScatterOpsScalar(Nd4jBackend backend) {
|
public void testScatterOpsScalar(Nd4jBackend backend) {
|
||||||
for (String s : new String[]{"add", "sub", "mul", "div"}) {
|
for (String s : new String[]{"add", "sub", "mul", "div"}) {
|
||||||
INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
|
INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
|
||||||
|
@ -1535,9 +1508,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
|
|
||||||
@Disabled("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540")
|
@Disabled("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540")
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPad(Nd4jBackend backend) {
|
public void testPad(Nd4jBackend backend) {
|
||||||
INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
|
INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
|
||||||
INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG);
|
INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG);
|
||||||
|
@ -1564,9 +1536,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMirrorPad(Nd4jBackend backend) {
|
public void testMirrorPad(Nd4jBackend backend) {
|
||||||
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
|
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
|
||||||
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
|
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
|
||||||
|
@ -1599,9 +1570,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err2);
|
assertNull(err2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMirrorPad2(Nd4jBackend backend) {
|
public void testMirrorPad2(Nd4jBackend backend) {
|
||||||
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
|
INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
|
||||||
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
|
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
|
||||||
|
@ -1627,9 +1597,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMirrorPadSymmetric(Nd4jBackend backend) {
|
public void testMirrorPadSymmetric(Nd4jBackend backend) {
|
||||||
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4);
|
INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4);
|
||||||
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT);
|
INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT);
|
||||||
|
@ -1656,9 +1625,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUnique(Nd4jBackend backend) {
|
public void testUnique(Nd4jBackend backend) {
|
||||||
INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4});
|
INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4});
|
||||||
|
|
||||||
|
@ -1680,9 +1648,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTopK(Nd4jBackend backend) {
|
public void testTopK(Nd4jBackend backend) {
|
||||||
OpValidationSuite.ignoreFailing(); //Can't assume sorted here
|
OpValidationSuite.ignoreFailing(); //Can't assume sorted here
|
||||||
INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8});
|
INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8});
|
||||||
|
@ -1711,9 +1678,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTopK1(Nd4jBackend backend) {
|
public void testTopK1(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
|
INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
|
||||||
INDArray k = Nd4j.scalar(1);
|
INDArray k = Nd4j.scalar(1);
|
||||||
|
@ -1734,9 +1700,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertEquals(expIdx, outIdx);
|
assertEquals(expIdx, outIdx);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInTopK(Nd4jBackend backend) {
|
public void testInTopK(Nd4jBackend backend) {
|
||||||
for (int k = 4; k >= 1; k--) {
|
for (int k = 4; k >= 1; k--) {
|
||||||
log.info("Testing: k=" + k);
|
log.info("Testing: k=" + k);
|
||||||
|
@ -1777,9 +1742,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testZeta(Nd4jBackend backend) {
|
public void testZeta(Nd4jBackend backend) {
|
||||||
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182
|
OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182
|
||||||
INDArray x = Nd4j.rand(3, 4).addi(1.0);
|
INDArray x = Nd4j.rand(3, 4).addi(1.0);
|
||||||
|
@ -1796,9 +1760,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNotEquals(Nd4j.create(out.shape()), out);
|
assertNotEquals(Nd4j.create(out.shape()), out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMaxEmptyScalar(Nd4jBackend backend) {
|
public void testMaxEmptyScalar(Nd4jBackend backend) {
|
||||||
INDArray empty = Nd4j.empty(DataType.FLOAT);
|
INDArray empty = Nd4j.empty(DataType.FLOAT);
|
||||||
INDArray scalar = Nd4j.scalar(1.0f);
|
INDArray scalar = Nd4j.scalar(1.0f);
|
||||||
|
@ -1815,9 +1778,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertTrue(isEmpty);
|
assertTrue(isEmpty);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBroadcastEmpty(Nd4jBackend backend) {
|
public void testBroadcastEmpty(Nd4jBackend backend) {
|
||||||
// Nd4j.getExecutioner().enableVerboseMode(true);
|
// Nd4j.getExecutioner().enableVerboseMode(true);
|
||||||
// Nd4j.getExecutioner().enableDebugMode(true);
|
// Nd4j.getExecutioner().enableDebugMode(true);
|
||||||
|
@ -1907,9 +1869,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStandardize(Nd4jBackend backend) {
|
public void testStandardize(Nd4jBackend backend) {
|
||||||
final INDArray random = Nd4j.rand(new int[]{10, 4});
|
final INDArray random = Nd4j.rand(new int[]{10, 4});
|
||||||
|
|
||||||
|
@ -1930,9 +1891,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err, err);
|
assertNull(err, err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStandardizeOP(Nd4jBackend backend) {
|
public void testStandardizeOP(Nd4jBackend backend) {
|
||||||
final INDArray random = Nd4j.rand(new int[]{10, 4});
|
final INDArray random = Nd4j.rand(new int[]{10, 4});
|
||||||
|
|
||||||
|
@ -1947,9 +1907,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertEquals(res, output);
|
assertEquals(res, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStandardizeNoDeviation(Nd4jBackend backend) {
|
public void testStandardizeNoDeviation(Nd4jBackend backend) {
|
||||||
final INDArray random = Nd4j.rand(new int[]{10, 4});
|
final INDArray random = Nd4j.rand(new int[]{10, 4});
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
|
@ -1975,9 +1934,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err, err);
|
assertNull(err, err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMatMulTensor(Nd4jBackend backend) {
|
public void testMatMulTensor(Nd4jBackend backend) {
|
||||||
final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5});
|
final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5});
|
||||||
final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6});
|
final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6});
|
||||||
|
@ -1997,9 +1955,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err, err);
|
assertNull(err, err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMatMulTensorTranspose(Nd4jBackend backend) {
|
public void testMatMulTensorTranspose(Nd4jBackend backend) {
|
||||||
for (boolean transposeA : new boolean[]{false, true}) {
|
for (boolean transposeA : new boolean[]{false, true}) {
|
||||||
for (boolean transposeB : new boolean[]{false, true}) {
|
for (boolean transposeB : new boolean[]{false, true}) {
|
||||||
|
@ -2092,9 +2049,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSoftmaxCF(Nd4jBackend backend) {
|
public void testSoftmaxCF(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5);
|
INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5);
|
||||||
|
@ -2115,9 +2071,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertEquals(outCC, outFF);
|
assertEquals(outCC, outFF);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLogSumExp(Nd4jBackend backend) {
|
public void testLogSumExp(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4);
|
INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4);
|
||||||
|
@ -2132,9 +2087,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertEquals(log, out);
|
assertEquals(log, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLogSumExp2(Nd4jBackend backend) {
|
public void testLogSumExp2(Nd4jBackend backend) {
|
||||||
|
|
||||||
for (int dim = 0; dim <= 2; dim++) {
|
for (int dim = 0; dim <= 2; dim++) {
|
||||||
|
@ -2155,9 +2109,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCRELU(Nd4jBackend backend) {
|
public void testCRELU(Nd4jBackend backend) {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -2176,9 +2129,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testClipByAvgNorm(Nd4jBackend backend) {
|
public void testClipByAvgNorm(Nd4jBackend backend) {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -2199,9 +2151,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEmbeddingLookup(Nd4jBackend backend) {
|
public void testEmbeddingLookup(Nd4jBackend backend) {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -2214,9 +2165,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testImageResize(Nd4jBackend backend) {
|
public void testImageResize(Nd4jBackend backend) {
|
||||||
|
|
||||||
//TODO: Methods failed ResizeLanczos5, ResizeMitchelcubic, ResizeArea
|
//TODO: Methods failed ResizeLanczos5, ResizeMitchelcubic, ResizeArea
|
||||||
|
@ -2258,9 +2208,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMaximumBp(Nd4jBackend backend) {
|
public void testMaximumBp(Nd4jBackend backend) {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -2277,9 +2226,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMergeAddBp(Nd4jBackend backend) {
|
public void testMergeAddBp(Nd4jBackend backend) {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -2296,9 +2244,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMergeMaxBp(Nd4jBackend backend) {
|
public void testMergeMaxBp(Nd4jBackend backend) {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -2316,9 +2263,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMergeAvgBp(Nd4jBackend backend) {
|
public void testMergeAvgBp(Nd4jBackend backend) {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -2335,9 +2281,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReverseBp(Nd4jBackend backend) {
|
public void testReverseBp(Nd4jBackend backend) {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -2351,9 +2296,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUpsampling3dBp(Nd4jBackend backend) {
|
public void testUpsampling3dBp(Nd4jBackend backend) {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -45,9 +45,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDeConv2D(Nd4jBackend backend){
|
public void testDeConv2D(Nd4jBackend backend){
|
||||||
DeConv2DConfig.builder().kH(2).kW(4).build();
|
DeConv2DConfig.builder().kH(2).kW(4).build();
|
||||||
|
|
||||||
|
@ -108,9 +107,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testConv2D(Nd4jBackend backend){
|
public void testConv2D(Nd4jBackend backend){
|
||||||
Conv2DConfig.builder().kH(2).kW(4).build();
|
Conv2DConfig.builder().kH(2).kW(4).build();
|
||||||
|
|
||||||
|
@ -171,9 +169,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testPooling2D(Nd4jBackend backend){
|
public void testPooling2D(Nd4jBackend backend){
|
||||||
Pooling2DConfig.builder().kH(2).kW(4).build();
|
Pooling2DConfig.builder().kH(2).kW(4).build();
|
||||||
|
|
||||||
|
@ -234,9 +231,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testDeConv3D(Nd4jBackend backend){
|
public void testDeConv3D(Nd4jBackend backend){
|
||||||
DeConv3DConfig.builder().kH(2).kW(4).kD(3).build();
|
DeConv3DConfig.builder().kH(2).kW(4).kD(3).build();
|
||||||
|
|
||||||
|
@ -325,9 +321,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testConv3D(Nd4jBackend backend){
|
public void testConv3D(Nd4jBackend backend){
|
||||||
Conv3DConfig.builder().kH(2).kW(4).kD(3).build();
|
Conv3DConfig.builder().kH(2).kW(4).kD(3).build();
|
||||||
|
|
||||||
|
@ -418,9 +413,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testPooling3D(Nd4jBackend backend){
|
public void testPooling3D(Nd4jBackend backend){
|
||||||
Pooling3DConfig.builder().kH(2).kW(4).kD(3).build();
|
Pooling3DConfig.builder().kH(2).kW(4).kD(3).build();
|
||||||
|
|
||||||
|
@ -509,9 +503,8 @@ public class ConvConfigTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testConv1D(){
|
public void testConv1D(){
|
||||||
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();
|
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();
|
||||||
|
|
||||||
|
|
|
@ -50,9 +50,8 @@ public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEye(Nd4jBackend backend){
|
public void testEye(Nd4jBackend backend){
|
||||||
//OpValidationSuite.ignoreFailing();
|
//OpValidationSuite.ignoreFailing();
|
||||||
INDArray arr = Nd4j.create(new double[]{1, 0, 0, 0, 1, 0}, new int[]{2, 3});
|
INDArray arr = Nd4j.create(new double[]{1, 0, 0, 0, 1, 0}, new int[]{2, 3});
|
||||||
|
@ -68,9 +67,8 @@ public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(expOut, result.eval());
|
assertEquals(expOut, result.eval());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEyeShape(Nd4jBackend backend){
|
public void testEyeShape(Nd4jBackend backend){
|
||||||
val dco = DynamicCustomOp.builder("eye")
|
val dco = DynamicCustomOp.builder("eye")
|
||||||
.addIntegerArguments(3,3)
|
.addIntegerArguments(3,3)
|
||||||
|
@ -82,9 +80,8 @@ public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(new long[]{3,3}, list.get(0).getShape());
|
assertArrayEquals(new long[]{3,3}, list.get(0).getShape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testExecutionDifferentShapesTransform(Nd4jBackend backend){
|
public void testExecutionDifferentShapesTransform(Nd4jBackend backend){
|
||||||
OpValidationSuite.ignoreFailing();
|
OpValidationSuite.ignoreFailing();
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -105,9 +102,8 @@ public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp, out2);
|
assertEquals(exp, out2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDropout(Nd4jBackend backend) {
|
public void testDropout(Nd4jBackend backend) {
|
||||||
OpValidationSuite.ignoreFailing();
|
OpValidationSuite.ignoreFailing();
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -120,9 +116,8 @@ public class FailingSameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(new long[]{2, 2}, res.getShape());
|
assertArrayEquals(new long[]{2, 2}, res.getShape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testExecutionDifferentShapesDynamicCustom(Nd4jBackend backend){
|
public void testExecutionDifferentShapesDynamicCustom(Nd4jBackend backend){
|
||||||
OpValidationSuite.ignoreFailing();
|
OpValidationSuite.ignoreFailing();
|
||||||
|
|
||||||
|
|
|
@ -67,9 +67,7 @@ import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertNotNull;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
|
@ -82,9 +80,8 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
|
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
|
||||||
|
@ -139,9 +136,8 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(sd.getLossVariables().size(), fg.lossVariablesLength());
|
assertEquals(sd.getLossVariables().size(), fg.lossVariablesLength());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||||
for( int i = 0; i < 10; i++ ) {
|
for( int i = 0; i < 10; i++ ) {
|
||||||
for(boolean execFirst : new boolean[]{false, true}) {
|
for(boolean execFirst : new boolean[]{false, true}) {
|
||||||
|
@ -270,9 +266,8 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
//Ensure 2 things:
|
//Ensure 2 things:
|
||||||
|
@ -356,9 +351,8 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void pooling3DSerialization(Nd4jBackend backend){
|
public void pooling3DSerialization(Nd4jBackend backend){
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
||||||
|
@ -378,9 +372,8 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
deserialized.getVariableOutputOp("pool").getClass());
|
deserialized.getVariableOutputOp("pool").getClass());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void pooling3DSerialization2(Nd4jBackend backend){
|
public void pooling3DSerialization2(Nd4jBackend backend){
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
||||||
|
|
|
@ -52,9 +52,8 @@ public class GraphTransformUtilTests extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBasic(Nd4jBackend backend){
|
public void testBasic(Nd4jBackend backend){
|
||||||
|
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -93,9 +92,8 @@ public class GraphTransformUtilTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(0, sg2.getChildNodes().size());
|
assertEquals(0, sg2.getChildNodes().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSubgraphReplace1(Nd4jBackend backend){
|
public void testSubgraphReplace1(Nd4jBackend backend){
|
||||||
|
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
|
@ -42,9 +42,8 @@ public class MemoryMgrTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testArrayReuseTooLarge(Nd4jBackend backend) throws Exception {
|
public void testArrayReuseTooLarge(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
|
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
|
||||||
|
@ -116,9 +115,8 @@ public class MemoryMgrTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(0, mmgr.getLruCacheValues().size());
|
assertEquals(0, mmgr.getLruCacheValues().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testManyArrays(Nd4jBackend backend){
|
public void testManyArrays(Nd4jBackend backend){
|
||||||
|
|
||||||
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
|
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
|
||||||
|
|
|
@ -45,9 +45,8 @@ public class NameScopeTests extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testVariableNameScopesBasic(Nd4jBackend backend) {
|
public void testVariableNameScopesBasic(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -73,9 +72,8 @@ public class NameScopeTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOpFieldsAndNames(Nd4jBackend backend) {
|
public void testOpFieldsAndNames(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -153,9 +151,8 @@ public class NameScopeTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNoNesting(Nd4jBackend backend) {
|
public void testNoNesting(Nd4jBackend backend) {
|
||||||
SameDiff SD = SameDiff.create();
|
SameDiff SD = SameDiff.create();
|
||||||
|
|
||||||
|
@ -172,9 +169,8 @@ public class NameScopeTests extends BaseNd4jTestWithBackends {
|
||||||
assertTrue(SD.variableMap().containsKey("test/argmax"),"Var with name test/argmax exists");
|
assertTrue(SD.variableMap().containsKey("test/argmax"),"Var with name test/argmax exists");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNoTesting2(Nd4jBackend backend) {
|
public void testNoTesting2(Nd4jBackend backend) {
|
||||||
SameDiff SD = SameDiff.create();
|
SameDiff SD = SameDiff.create();
|
||||||
|
|
||||||
|
|
|
@ -49,9 +49,8 @@ public class SameDiffMultiThreadTests extends BaseND4JTest {
|
||||||
return Long.MAX_VALUE;
|
return Long.MAX_VALUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSimple(Nd4jBackend backend) throws Exception {
|
public void testSimple(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
int nThreads = 4;
|
int nThreads = 4;
|
||||||
|
|
|
@ -36,9 +36,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
public class SameDiffOutputTest extends BaseNd4jTestWithBackends {
|
public class SameDiffOutputTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void outputTest(Nd4jBackend backend){
|
public void outputTest(Nd4jBackend backend){
|
||||||
DataSet data = new DataSet(Nd4j.zeros(10, 10), Nd4j.zeros(10, 10));
|
DataSet data = new DataSet(Nd4j.zeros(10, 10), Nd4j.zeros(10, 10));
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
|
@ -33,8 +33,6 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.learning.config.Adam;
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertNotNull;
|
|
||||||
import static junit.framework.TestCase.assertNull;
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
|
public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
|
||||||
|
@ -45,9 +43,8 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSpecifiedLoss1(Nd4jBackend backend) {
|
public void testSpecifiedLoss1(Nd4jBackend backend) {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable ph1 = sd.var("ph", DataType.FLOAT, 3, 4);
|
SDVariable ph1 = sd.var("ph", DataType.FLOAT, 3, 4);
|
||||||
|
@ -68,11 +65,10 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
|
||||||
assertNotNull(ph1.gradient());
|
assertNotNull(ph1.gradient());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSpecifiedLoss2(Nd4jBackend backend) {
|
public void testSpecifiedLoss2(Nd4jBackend backend) {
|
||||||
for( int i=0; i<2; i++ ) {
|
for( int i = 0; i < 2; i++) {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable ph = sd.placeHolder("ph", DataType.FLOAT, 3, 4);
|
SDVariable ph = sd.placeHolder("ph", DataType.FLOAT, 3, 4);
|
||||||
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 5));
|
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 5));
|
||||||
|
@ -111,7 +107,7 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
for(String s : new String[]{"w", "b", badd.name(), add.name(), "l1", "l2"}){
|
for(String s : new String[]{"w", "b", badd.name(), add.name(), "l1", "l2"}){
|
||||||
SDVariable gradVar = sd.getVariable(s).gradient();
|
SDVariable gradVar = sd.getVariable(s).gradient();
|
||||||
assertNotNull(s, gradVar);
|
assertNotNull(gradVar,s);
|
||||||
}
|
}
|
||||||
//Unused:
|
//Unused:
|
||||||
assertFalse(shape.hasGradient());
|
assertFalse(shape.hasGradient());
|
||||||
|
@ -123,9 +119,8 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTrainingDifferentLosses(Nd4jBackend backend) {
|
public void testTrainingDifferentLosses(Nd4jBackend backend) {
|
||||||
//Net with 2 losses: train on the first one, then change losses
|
//Net with 2 losses: train on the first one, then change losses
|
||||||
//Also check that if modifying via add/setLossVariables the training config changes
|
//Also check that if modifying via add/setLossVariables the training config changes
|
||||||
|
@ -154,20 +149,20 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTestWithBackends {
|
||||||
sd.setLossVariables("loss1");
|
sd.setLossVariables("loss1");
|
||||||
sd.createGradFunction();
|
sd.createGradFunction();
|
||||||
for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){
|
for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){
|
||||||
assertNotNull(v.name(), v.gradient());
|
assertNotNull(v.gradient(),v.name());
|
||||||
}
|
}
|
||||||
for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){
|
for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){
|
||||||
assertNull(v.name(), v.gradient());
|
assertNull(v.gradient(),v.name());
|
||||||
}
|
}
|
||||||
|
|
||||||
//Now, set to other loss function
|
//Now, set to other loss function
|
||||||
sd.setLossVariables("loss2");
|
sd.setLossVariables("loss2");
|
||||||
sd.createGradFunction();
|
sd.createGradFunction();
|
||||||
for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){
|
for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){
|
||||||
assertNull(v.name(), v.gradient());
|
assertNull(v.gradient(),v.name());
|
||||||
}
|
}
|
||||||
for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){
|
for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){
|
||||||
assertNotNull(v.name(), v.gradient());
|
assertNotNull(v.gradient(),v.name());
|
||||||
}
|
}
|
||||||
|
|
||||||
//Train the first side of the graph. The other side should remain unmodified!
|
//Train the first side of the graph. The other side should remain unmodified!
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -60,9 +60,8 @@ import org.nd4j.weightinit.impl.XavierInitScheme;
|
||||||
public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
|
public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void irisTrainingSanityCheck(Nd4jBackend backend) {
|
public void irisTrainingSanityCheck(Nd4jBackend backend) {
|
||||||
|
|
||||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
|
@ -134,9 +133,8 @@ public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void irisTrainingEvalTest(Nd4jBackend backend) {
|
public void irisTrainingEvalTest(Nd4jBackend backend) {
|
||||||
|
|
||||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
|
@ -186,9 +184,8 @@ public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void irisTrainingValidationTest(Nd4jBackend backend) {
|
public void irisTrainingValidationTest(Nd4jBackend backend) {
|
||||||
|
|
||||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
|
@ -243,9 +240,8 @@ public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTrainingMixedDtypes(){
|
public void testTrainingMixedDtypes(){
|
||||||
|
|
||||||
for (String u : new String[]{"adam", "nesterov", "adamax", "amsgrad"}) {
|
for (String u : new String[]{"adam", "nesterov", "adamax", "amsgrad"}) {
|
||||||
|
@ -307,9 +303,8 @@ public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void simpleClassification(Nd4jBackend backend) {
|
public void simpleClassification(Nd4jBackend backend) {
|
||||||
double learning_rate = 0.001;
|
double learning_rate = 0.001;
|
||||||
int seed = 7;
|
int seed = 7;
|
||||||
|
@ -356,9 +351,8 @@ public class SameDiffTrainingTest extends BaseNd4jTestWithBackends {
|
||||||
History history = sd.fit(new SingletonMultiDataSetIterator(mds), 1);
|
History history = sd.fit(new SingletonMultiDataSetIterator(mds), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTrainingEvalVarNotReqForLoss(){
|
public void testTrainingEvalVarNotReqForLoss(){
|
||||||
//If a variable is not required for the loss - normally it won't be calculated
|
//If a variable is not required for the loss - normally it won't be calculated
|
||||||
//But we want to make sure it IS calculated here - so we can perform evaluation on it
|
//But we want to make sure it IS calculated here - so we can perform evaluation on it
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.samediff.listeners;
|
package org.nd4j.autodiff.samediff.listeners;
|
||||||
|
|
||||||
import org.junit.Assert;
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
@ -47,8 +47,9 @@ import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@ -94,9 +95,8 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
|
|
||||||
|
@ -130,9 +130,8 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
assertTrue(found1 && found2 && found3);
|
assertTrue(found1 && found2 && found3);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
|
|
||||||
|
@ -165,15 +164,14 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
assertEquals(5, files.length); //4 checkpoints and 1 text file (metadata)
|
assertEquals(5, files.length); //4 checkpoints and 1 text file (metadata)
|
||||||
|
|
||||||
for( int i=0; i<found.length; i++ ){
|
for( int i = 0; i < found.length; i++) {
|
||||||
assertTrue(names.get(i), found[i]);
|
assertTrue(found[i], names.get(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
SameDiff sd = getModel();
|
SameDiff sd = getModel();
|
||||||
|
@ -213,13 +211,12 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
for( int i = 0; i < found.length; i++) {
|
for( int i = 0; i < found.length; i++) {
|
||||||
assertTrue(names.get(i), found[i]);
|
assertTrue(found[i], names.get(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
SameDiff sd = getModel();
|
SameDiff sd = getModel();
|
||||||
|
@ -258,8 +255,8 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(5, cpNums.size(),cpNums.toString());
|
assertEquals(5, cpNums.size(),cpNums.toString());
|
||||||
Assert.assertTrue(cpNums.toString(), cpNums.containsAll(Arrays.asList(2, 5, 7, 8, 9)));
|
assertTrue(cpNums.containsAll(Arrays.asList(2, 5, 7, 8, 9)), cpNums.toString());
|
||||||
Assert.assertTrue(epochNums.toString(), epochNums.containsAll(Arrays.asList(5, 11, 15, 17, 19)));
|
assertTrue(epochNums.containsAll(Arrays.asList(5, 11, 15, 17, 19)), epochNums.toString());
|
||||||
|
|
||||||
assertEquals(5, l.availableCheckpoints().size());
|
assertEquals(5, l.availableCheckpoints().size());
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,9 +38,8 @@ import org.nd4j.linalg.learning.config.Adam;
|
||||||
public class ExecDebuggingListenerTest extends BaseNd4jTestWithBackends {
|
public class ExecDebuggingListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testExecDebugListener(Nd4jBackend backend) {
|
public void testExecDebugListener(Nd4jBackend backend) {
|
||||||
|
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
|
|
@ -71,9 +71,8 @@ public class ListenerTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void irisHistoryTest(Nd4jBackend backend) {
|
public void irisHistoryTest(Nd4jBackend backend) {
|
||||||
|
|
||||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
|
@ -136,9 +135,8 @@ public class ListenerTest extends BaseNd4jTestWithBackends {
|
||||||
assertTrue(acc >= 0.75,"Accuracy < 75%, was " + acc);
|
assertTrue(acc >= 0.75,"Accuracy < 75%, was " + acc);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testListenerCalls(){
|
public void testListenerCalls(){
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
|
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
|
||||||
|
@ -275,9 +273,8 @@ public class ListenerTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCustomListener(Nd4jBackend backend) {
|
public void testCustomListener(Nd4jBackend backend) {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
|
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
|
||||||
|
|
|
@ -57,9 +57,8 @@ public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
|
@ -108,25 +107,22 @@ public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testLoadTfProfile(){
|
public void testLoadTfProfile(){
|
||||||
File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json");
|
File f = new File("C:\\Temp\\sd_profiler\\tf_profile.json");
|
||||||
ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
|
ProfileAnalyzer.summarizeProfile(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testLoadTfProfileDir(){
|
public void testLoadTfProfileDir(){
|
||||||
File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles");
|
File f = new File("C:\\Temp\\sd_profiler\\tf_multiple_profiles");
|
||||||
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
|
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testLoadTfProfileDir2(){
|
public void testLoadTfProfileDir2(){
|
||||||
File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0");
|
File f = new File("C:\\DL4J\\Git\\dl4j-dev-tools\\import-tests\\profiling\\mobilenet_v2_1.0_224_batch32_tf-1.15.0");
|
||||||
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
|
ProfileAnalyzer.summarizeProfileDirectory(f, ProfileAnalyzer.ProfileFormat.TENSORFLOW);
|
||||||
|
|
|
@ -79,9 +79,8 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
||||||
Nd4j.getRandom().setSeed(123);
|
Nd4j.getRandom().setSeed(123);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
|
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
|
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
|
||||||
|
@ -185,9 +184,8 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{
|
public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
File f = new File(dir, "temp.bin");
|
File f = new File(dir, "temp.bin");
|
||||||
|
|
|
@ -63,9 +63,8 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -101,9 +100,8 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(new long[]{150, 3}, out.shape());
|
assertArrayEquals(new long[]{150, 3}, out.shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
|
|
||||||
|
@ -194,9 +192,8 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
SameDiff sd1 = getSimpleNet();
|
SameDiff sd1 = getSimpleNet();
|
||||||
|
|
|
@ -40,9 +40,8 @@ public class CustomEvaluationTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void customEvalTest(Nd4jBackend backend){
|
public void customEvalTest(Nd4jBackend backend){
|
||||||
CustomEvaluation accuracyEval = new CustomEvaluation<>(
|
CustomEvaluation accuracyEval = new CustomEvaluation<>(
|
||||||
(labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)),
|
(labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)),
|
||||||
|
|
|
@ -45,9 +45,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testEmptyEvaluation (Nd4jBackend backend) {
|
public void testEmptyEvaluation (Nd4jBackend backend) {
|
||||||
Evaluation e = new Evaluation();
|
Evaluation e = new Evaluation();
|
||||||
System.out.println(e.stats());
|
System.out.println(e.stats());
|
||||||
|
@ -62,9 +61,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEmptyRegressionEvaluation (Nd4jBackend backend) {
|
public void testEmptyRegressionEvaluation (Nd4jBackend backend) {
|
||||||
RegressionEvaluation re = new RegressionEvaluation();
|
RegressionEvaluation re = new RegressionEvaluation();
|
||||||
re.stats();
|
re.stats();
|
||||||
|
@ -78,9 +76,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEmptyEvaluationBinary(Nd4jBackend backend) {
|
public void testEmptyEvaluationBinary(Nd4jBackend backend) {
|
||||||
EvaluationBinary eb = new EvaluationBinary();
|
EvaluationBinary eb = new EvaluationBinary();
|
||||||
eb.stats();
|
eb.stats();
|
||||||
|
@ -95,9 +92,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEmptyROC(Nd4jBackend backend) {
|
public void testEmptyROC(Nd4jBackend backend) {
|
||||||
ROC roc = new ROC();
|
ROC roc = new ROC();
|
||||||
roc.stats();
|
roc.stats();
|
||||||
|
@ -112,9 +108,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEmptyROCBinary(Nd4jBackend backend) {
|
public void testEmptyROCBinary(Nd4jBackend backend) {
|
||||||
ROCBinary rb = new ROCBinary();
|
ROCBinary rb = new ROCBinary();
|
||||||
rb.stats();
|
rb.stats();
|
||||||
|
@ -129,9 +124,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEmptyROCMultiClass(Nd4jBackend backend) {
|
public void testEmptyROCMultiClass(Nd4jBackend backend) {
|
||||||
ROCMultiClass r = new ROCMultiClass();
|
ROCMultiClass r = new ROCMultiClass();
|
||||||
r.stats();
|
r.stats();
|
||||||
|
@ -146,9 +140,8 @@ public class EmptyEvaluationTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEmptyEvaluationCalibration(Nd4jBackend backend) {
|
public void testEmptyEvaluationCalibration(Nd4jBackend backend) {
|
||||||
EvaluationCalibration ec = new EvaluationCalibration();
|
EvaluationCalibration ec = new EvaluationCalibration();
|
||||||
ec.stats();
|
ec.stats();
|
||||||
|
|
|
@ -46,9 +46,8 @@ public class EvalCustomThreshold extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvaluationCustomBinaryThreshold(Nd4jBackend backend) {
|
public void testEvaluationCustomBinaryThreshold(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -114,9 +113,8 @@ public class EvalCustomThreshold extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(ex2.getConfusionMatrix(), e025v2.getConfusionMatrix());
|
assertEquals(ex2.getConfusionMatrix(), e025v2.getConfusionMatrix());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvaluationCostArray(Nd4jBackend backend) {
|
public void testEvaluationCostArray(Nd4jBackend backend) {
|
||||||
|
|
||||||
|
|
||||||
|
@ -164,9 +162,8 @@ public class EvalCustomThreshold extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(1.0, e2.accuracy(), 1e-6);
|
assertEquals(1.0, e2.accuracy(), 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvaluationBinaryCustomThreshold(Nd4jBackend backend) {
|
public void testEvaluationBinaryCustomThreshold(Nd4jBackend backend) {
|
||||||
|
|
||||||
//Sanity check: same results for 0.5 threshold vs. default (no threshold)
|
//Sanity check: same results for 0.5 threshold vs. default (no threshold)
|
||||||
|
|
|
@ -39,9 +39,7 @@ import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertNull;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
||||||
|
|
||||||
|
|
||||||
public class EvalJsonTest extends BaseNd4jTestWithBackends {
|
public class EvalJsonTest extends BaseNd4jTestWithBackends {
|
||||||
|
@ -52,9 +50,8 @@ public class EvalJsonTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSerdeEmpty(Nd4jBackend backend) {
|
public void testSerdeEmpty(Nd4jBackend backend) {
|
||||||
boolean print = false;
|
boolean print = false;
|
||||||
|
|
||||||
|
@ -74,9 +71,8 @@ public class EvalJsonTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testSerde(Nd4jBackend backend) {
|
public void testSerde(Nd4jBackend backend) {
|
||||||
boolean print = false;
|
boolean print = false;
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -124,9 +120,8 @@ public class EvalJsonTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testSerdeExactRoc(Nd4jBackend backend) {
|
public void testSerdeExactRoc(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
boolean print = false;
|
boolean print = false;
|
||||||
|
@ -204,9 +199,8 @@ public class EvalJsonTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testJsonYamlCurves(Nd4jBackend backend) {
|
public void testJsonYamlCurves(Nd4jBackend backend) {
|
||||||
ROC roc = new ROC(0);
|
ROC roc = new ROC(0);
|
||||||
|
|
||||||
|
@ -258,9 +252,8 @@ public class EvalJsonTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testJsonWithCustomThreshold(Nd4jBackend backend) {
|
public void testJsonWithCustomThreshold(Nd4jBackend backend) {
|
||||||
|
|
||||||
//Evaluation - binary threshold
|
//Evaluation - binary threshold
|
||||||
|
|
|
@ -50,9 +50,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEval(Nd4jBackend backend) {
|
public void testEval(Nd4jBackend backend) {
|
||||||
int classNum = 5;
|
int classNum = 5;
|
||||||
Evaluation eval = new Evaluation (classNum);
|
Evaluation eval = new Evaluation (classNum);
|
||||||
|
@ -91,9 +90,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(0.5, eval.accuracy(), 0);
|
assertEquals(0.5, eval.accuracy(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEval2(Nd4jBackend backend) {
|
public void testEval2(Nd4jBackend backend) {
|
||||||
|
|
||||||
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
||||||
|
@ -152,9 +150,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStringListLabels(Nd4jBackend backend) {
|
public void testStringListLabels(Nd4jBackend backend) {
|
||||||
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
|
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
|
||||||
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
|
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
|
||||||
|
@ -171,9 +168,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStringHashLabels(Nd4jBackend backend) {
|
public void testStringHashLabels(Nd4jBackend backend) {
|
||||||
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
|
INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
|
||||||
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
|
INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
|
||||||
|
@ -190,9 +186,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvalMasking(Nd4jBackend backend) {
|
public void testEvalMasking(Nd4jBackend backend) {
|
||||||
int miniBatch = 5;
|
int miniBatch = 5;
|
||||||
int nOut = 3;
|
int nOut = 3;
|
||||||
|
@ -259,9 +254,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testFalsePerfectRecall(Nd4jBackend backend) {
|
public void testFalsePerfectRecall(Nd4jBackend backend) {
|
||||||
int testSize = 100;
|
int testSize = 100;
|
||||||
int numClasses = 5;
|
int numClasses = 5;
|
||||||
|
@ -294,9 +288,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
assertNotEquals(1.0, eval.recall());
|
assertNotEquals(1.0, eval.recall());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvaluationMerging(Nd4jBackend backend) {
|
public void testEvaluationMerging(Nd4jBackend backend) {
|
||||||
|
|
||||||
int nRows = 20;
|
int nRows = 20;
|
||||||
|
@ -370,9 +363,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSingleClassBinaryClassification(Nd4jBackend backend) {
|
public void testSingleClassBinaryClassification(Nd4jBackend backend) {
|
||||||
|
|
||||||
Evaluation eval = new Evaluation(1);
|
Evaluation eval = new Evaluation(1);
|
||||||
|
@ -401,9 +393,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvalInvalid(Nd4jBackend backend) {
|
public void testEvalInvalid(Nd4jBackend backend) {
|
||||||
Evaluation e = new Evaluation(5);
|
Evaluation e = new Evaluation(5);
|
||||||
e.eval(0, 1);
|
e.eval(0, 1);
|
||||||
|
@ -416,9 +407,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
assertFalse(e.stats().contains("\uFFFD"));
|
assertFalse(e.stats().contains("\uFFFD"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvalMethods(Nd4jBackend backend) {
|
public void testEvalMethods(Nd4jBackend backend) {
|
||||||
//Check eval(int,int) vs. eval(INDArray,INDArray)
|
//Check eval(int,int) vs. eval(INDArray,INDArray)
|
||||||
|
|
||||||
|
@ -461,9 +451,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTopNAccuracy(Nd4jBackend backend) {
|
public void testTopNAccuracy(Nd4jBackend backend) {
|
||||||
|
|
||||||
Evaluation e = new Evaluation(null, 3);
|
Evaluation e = new Evaluation(null, 3);
|
||||||
|
@ -524,9 +513,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTopNAccuracyMerging(Nd4jBackend backend) {
|
public void testTopNAccuracyMerging(Nd4jBackend backend) {
|
||||||
|
|
||||||
Evaluation e1 = new Evaluation(null, 3);
|
Evaluation e1 = new Evaluation(null, 3);
|
||||||
|
@ -574,9 +562,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(6.0 / 8, e1.topNAccuracy(), 1e-6);
|
assertEquals(6.0 / 8, e1.topNAccuracy(), 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBinaryCase(Nd4jBackend backend) {
|
public void testBinaryCase(Nd4jBackend backend) {
|
||||||
INDArray ones10 = Nd4j.ones(10, 1);
|
INDArray ones10 = Nd4j.ones(10, 1);
|
||||||
INDArray ones4 = Nd4j.ones(4, 1);
|
INDArray ones4 = Nd4j.ones(4, 1);
|
||||||
|
@ -605,9 +592,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(2, (int) e.truePositives().get(0));
|
assertEquals(2, (int) e.truePositives().get(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testF1FBeta_MicroMacroAveraging(Nd4jBackend backend) {
|
public void testF1FBeta_MicroMacroAveraging(Nd4jBackend backend) {
|
||||||
//Confusion matrix: rows = actual, columns = predicted
|
//Confusion matrix: rows = actual, columns = predicted
|
||||||
//[3, 1, 0]
|
//[3, 1, 0]
|
||||||
|
@ -748,9 +734,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConfusionMatrixStats(Nd4jBackend backend) {
|
public void testConfusionMatrixStats(Nd4jBackend backend) {
|
||||||
|
|
||||||
Evaluation e = new Evaluation();
|
Evaluation e = new Evaluation();
|
||||||
|
@ -771,9 +756,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvalBinaryMetrics(){
|
public void testEvalBinaryMetrics(){
|
||||||
|
|
||||||
Evaluation ePosClass1_nOut2 = new Evaluation(2, 1);
|
Evaluation ePosClass1_nOut2 = new Evaluation(2, 1);
|
||||||
|
@ -894,9 +878,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConfusionMatrixString(){
|
public void testConfusionMatrixString(){
|
||||||
|
|
||||||
Evaluation e = new Evaluation(Arrays.asList("a","b","c"));
|
Evaluation e = new Evaluation(Arrays.asList("a","b","c"));
|
||||||
|
@ -946,9 +929,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
e.stats(false, true);
|
e.stats(false, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvaluationNaNs(){
|
public void testEvaluationNaNs(){
|
||||||
|
|
||||||
Evaluation e = new Evaluation();
|
Evaluation e = new Evaluation();
|
||||||
|
@ -963,9 +945,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSegmentation(){
|
public void testSegmentation(){
|
||||||
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -1059,9 +1040,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLabelReset(){
|
public void testLabelReset(){
|
||||||
|
|
||||||
Map<Integer,String> m = new HashMap<>();
|
Map<Integer,String> m = new HashMap<>();
|
||||||
|
@ -1094,9 +1074,8 @@ public class EvalTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(s1, s2);
|
assertEquals(s1, s2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvalStatsBinaryCase(){
|
public void testEvalStatsBinaryCase(){
|
||||||
//Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case
|
//Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case
|
||||||
|
|
||||||
|
|
|
@ -48,9 +48,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvaluationBinary(Nd4jBackend backend) {
|
public void testEvaluationBinary(Nd4jBackend backend) {
|
||||||
//Compare EvaluationBinary to Evaluation class
|
//Compare EvaluationBinary to Evaluation class
|
||||||
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
||||||
|
@ -136,9 +135,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvaluationBinaryMerging(Nd4jBackend backend) {
|
public void testEvaluationBinaryMerging(Nd4jBackend backend) {
|
||||||
int nOut = 4;
|
int nOut = 4;
|
||||||
int[] shape1 = {30, nOut};
|
int[] shape1 = {30, nOut};
|
||||||
|
@ -165,9 +163,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(eb.stats(), eb1.stats());
|
assertEquals(eb.stats(), eb1.stats());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvaluationBinaryPerOutputMasking(Nd4jBackend backend) {
|
public void testEvaluationBinaryPerOutputMasking(Nd4jBackend backend) {
|
||||||
|
|
||||||
//Provide a mask array: "ignore" the masked steps
|
//Provide a mask array: "ignore" the masked steps
|
||||||
|
@ -210,9 +207,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(1, eb.falseNegatives(2));
|
assertEquals(1, eb.falseNegatives(2));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTimeSeriesEval(Nd4jBackend backend) {
|
public void testTimeSeriesEval(Nd4jBackend backend) {
|
||||||
|
|
||||||
int[] shape = {2, 4, 3};
|
int[] shape = {2, 4, 3};
|
||||||
|
@ -236,9 +232,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(eb2.stats(), eb1.stats());
|
assertEquals(eb2.stats(), eb1.stats());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvaluationBinaryWithROC(Nd4jBackend backend) {
|
public void testEvaluationBinaryWithROC(Nd4jBackend backend) {
|
||||||
//Simple test for nested ROCBinary in EvaluationBinary
|
//Simple test for nested ROCBinary in EvaluationBinary
|
||||||
|
|
||||||
|
@ -255,9 +250,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvaluationBinary3d(Nd4jBackend backend) {
|
public void testEvaluationBinary3d(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||||
|
@ -291,9 +285,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvaluationBinary4d(Nd4jBackend backend) {
|
public void testEvaluationBinary4d(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
|
@ -327,9 +320,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvaluationBinary3dMasking(Nd4jBackend backend) {
|
public void testEvaluationBinary3dMasking(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||||
|
@ -390,9 +382,8 @@ public class EvaluationBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvaluationBinary4dMasking(Nd4jBackend backend) {
|
public void testEvaluationBinary4dMasking(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
|
|
|
@ -49,9 +49,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testReliabilityDiagram (Nd4jBackend backend) {
|
public void testReliabilityDiagram (Nd4jBackend backend) {
|
||||||
|
|
||||||
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
||||||
|
@ -143,9 +142,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testLabelAndPredictionCounts (Nd4jBackend backend) {
|
public void testLabelAndPredictionCounts (Nd4jBackend backend) {
|
||||||
|
|
||||||
int minibatch = 50;
|
int minibatch = 50;
|
||||||
|
@ -173,9 +171,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(expPredictionCount, ec.getPredictionCountsEachClass());
|
assertArrayEquals(expPredictionCount, ec.getPredictionCountsEachClass());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testResidualPlots (Nd4jBackend backend) {
|
public void testResidualPlots (Nd4jBackend backend) {
|
||||||
|
|
||||||
int minibatch = 50;
|
int minibatch = 50;
|
||||||
|
@ -276,9 +273,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testSegmentation(){
|
public void testSegmentation(){
|
||||||
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -372,9 +368,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testEvaluationCalibration3d (Nd4jBackend backend) {
|
public void testEvaluationCalibration3d (Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||||
|
@ -406,9 +401,8 @@ public class EvaluationCalibrationTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(e2d.stats(), e3d.stats());
|
assertEquals(e2d.stats(), e3d.stats());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testEvaluationCalibration3dMasking (Nd4jBackend backend) {
|
public void testEvaluationCalibration3dMasking (Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||||
|
|
|
@ -46,9 +46,8 @@ public class NewInstanceTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNewInstances(Nd4jBackend backend) {
|
public void testNewInstances(Nd4jBackend backend) {
|
||||||
boolean print = true;
|
boolean print = true;
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -48,9 +48,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testROCBinary(Nd4jBackend backend) {
|
public void testROCBinary(Nd4jBackend backend) {
|
||||||
//Compare ROCBinary to ROC class
|
//Compare ROCBinary to ROC class
|
||||||
|
|
||||||
|
@ -145,9 +144,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testRocBinaryMerging(Nd4jBackend backend) {
|
public void testRocBinaryMerging(Nd4jBackend backend) {
|
||||||
for (int nSteps : new int[]{30, 0}) { //0 == exact
|
for (int nSteps : new int[]{30, 0}) { //0 == exact
|
||||||
int nOut = 4;
|
int nOut = 4;
|
||||||
|
@ -177,9 +175,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testROCBinaryPerOutputMasking(Nd4jBackend backend) {
|
public void testROCBinaryPerOutputMasking(Nd4jBackend backend) {
|
||||||
|
|
||||||
for (int nSteps : new int[]{30, 0}) { //0 == exact
|
for (int nSteps : new int[]{30, 0}) { //0 == exact
|
||||||
|
@ -219,9 +216,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testROCBinary3d(Nd4jBackend backend) {
|
public void testROCBinary3d(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||||
|
@ -255,9 +251,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testROCBinary4d(Nd4jBackend backend) {
|
public void testROCBinary4d(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
|
@ -291,9 +286,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testROCBinary3dMasking(Nd4jBackend backend) {
|
public void testROCBinary3dMasking(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||||
|
@ -354,9 +348,8 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testROCBinary4dMasking(Nd4jBackend backend) {
|
public void testROCBinary4dMasking(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
|
|
|
@ -82,9 +82,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
expFPR.put(10 / 10.0, 0.0 / totalNegatives);
|
expFPR.put(10 / 10.0, 0.0 / totalNegatives);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testRocBasic(Nd4jBackend backend) {
|
public void testRocBasic(Nd4jBackend backend) {
|
||||||
//2 outputs here - probability distribution over classes (softmax)
|
//2 outputs here - probability distribution over classes (softmax)
|
||||||
INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
||||||
|
@ -127,9 +126,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(1.0, auc, 1e-6);
|
assertEquals(1.0, auc, 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testRocBasicSingleClass(Nd4jBackend backend) {
|
public void testRocBasicSingleClass(Nd4jBackend backend) {
|
||||||
//1 output here - single probability value (sigmoid)
|
//1 output here - single probability value (sigmoid)
|
||||||
|
|
||||||
|
@ -167,9 +165,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testRoc(Nd4jBackend backend) {
|
public void testRoc(Nd4jBackend backend) {
|
||||||
//Previous tests allowed for a perfect classifier with right threshold...
|
//Previous tests allowed for a perfect classifier with right threshold...
|
||||||
|
|
||||||
|
@ -254,9 +251,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
|
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
|
||||||
//Same as first test...
|
//Same as first test...
|
||||||
|
|
||||||
|
@ -303,9 +299,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testRocTimeSeriesMasking(Nd4jBackend backend) {
|
public void testRocTimeSeriesMasking(Nd4jBackend backend) {
|
||||||
//2 outputs here - probability distribution over classes (softmax)
|
//2 outputs here - probability distribution over classes (softmax)
|
||||||
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
||||||
|
@ -355,9 +350,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
|
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -387,9 +381,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testCompare2Vs3Classes(Nd4jBackend backend) {
|
public void testCompare2Vs3Classes(Nd4jBackend backend) {
|
||||||
|
|
||||||
//ROC multi-class: 2 vs. 3 classes should be the same, if we add two of the classes together...
|
//ROC multi-class: 2 vs. 3 classes should be the same, if we add two of the classes together...
|
||||||
|
@ -438,9 +431,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testROCMerging(Nd4jBackend backend) {
|
public void testROCMerging(Nd4jBackend backend) {
|
||||||
int nArrays = 10;
|
int nArrays = 10;
|
||||||
int minibatch = 64;
|
int minibatch = 64;
|
||||||
|
@ -485,9 +477,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testROCMerging2(Nd4jBackend backend) {
|
public void testROCMerging2(Nd4jBackend backend) {
|
||||||
int nArrays = 10;
|
int nArrays = 10;
|
||||||
int minibatch = 64;
|
int minibatch = 64;
|
||||||
|
@ -532,9 +523,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testROCMultiMerging(Nd4jBackend backend) {
|
public void testROCMultiMerging(Nd4jBackend backend) {
|
||||||
|
|
||||||
int nArrays = 10;
|
int nArrays = 10;
|
||||||
|
@ -582,9 +572,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testAUCPrecisionRecall(Nd4jBackend backend) {
|
public void testAUCPrecisionRecall(Nd4jBackend backend) {
|
||||||
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
|
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
|
||||||
//at threshold 0 to 0.24999: tp=2, fp=1, fn=0, tn=0 prec=2/(2+1)=0.666, recall=2/2=1.0
|
//at threshold 0 to 0.24999: tp=2, fp=1, fn=0, tn=0 prec=2/(2+1)=0.666, recall=2/2=1.0
|
||||||
|
@ -631,9 +620,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testRocAucExact(Nd4jBackend backend) {
|
public void testRocAucExact(Nd4jBackend backend) {
|
||||||
|
|
||||||
//Check the implementation vs. Scikitlearn
|
//Check the implementation vs. Scikitlearn
|
||||||
|
@ -796,9 +784,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
|
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
|
||||||
|
|
||||||
//Set reallocation block size to say 20, but then evaluate a 100-length array
|
//Set reallocation block size to say 20, but then evaluate a 100-length array
|
||||||
|
@ -810,9 +797,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
|
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
|
||||||
double[] threshold = new double[101];
|
double[] threshold = new double[101];
|
||||||
double[] precision = threshold;
|
double[] precision = threshold;
|
||||||
|
@ -848,9 +834,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
|
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
|
||||||
//Sanity check: values calculated from the confusion matrix should match the PR curve values
|
//Sanity check: values calculated from the confusion matrix should match the PR curve values
|
||||||
|
|
||||||
|
@ -889,9 +874,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testRocMerge(){
|
public void testRocMerge(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -935,9 +919,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(auprc, auprcAct, 1e-6);
|
assertEquals(auprc, auprcAct, 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testRocMultiMerge(){
|
public void testRocMultiMerge(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -986,9 +969,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testRocBinaryMerge(){
|
public void testRocBinaryMerge(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -1033,9 +1015,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testSegmentationBinary(){
|
public void testSegmentationBinary(){
|
||||||
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -1125,9 +1106,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testSegmentation(){
|
public void testSegmentation(){
|
||||||
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -63,9 +63,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPerfectPredictions(Nd4jBackend backend) {
|
public void testPerfectPredictions(Nd4jBackend backend) {
|
||||||
|
|
||||||
int nCols = 5;
|
int nCols = 5;
|
||||||
|
@ -92,9 +91,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testKnownValues(Nd4jBackend backend) {
|
public void testKnownValues(Nd4jBackend backend) {
|
||||||
|
|
||||||
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
||||||
|
@ -150,9 +148,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRegressionEvaluationMerging(Nd4jBackend backend) {
|
public void testRegressionEvaluationMerging(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -193,9 +190,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRegressionEvalPerOutputMasking(Nd4jBackend backend) {
|
public void testRegressionEvalPerOutputMasking(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}});
|
INDArray l = Nd4j.create(new double[][] {{1, 2, 3}, {10, 20, 30}, {-5, -10, -20}});
|
||||||
|
@ -222,9 +218,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRegressionEvalTimeSeriesSplit(){
|
public void testRegressionEvalTimeSeriesSplit(){
|
||||||
|
|
||||||
INDArray out1 = Nd4j.rand(new int[]{3, 5, 20});
|
INDArray out1 = Nd4j.rand(new int[]{3, 5, 20});
|
||||||
|
@ -246,9 +241,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(e1, e2);
|
assertEquals(e1, e2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRegressionEval3d(Nd4jBackend backend) {
|
public void testRegressionEval3d(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||||
|
@ -280,9 +274,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRegressionEval4d(Nd4jBackend backend) {
|
public void testRegressionEval4d(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
|
@ -314,9 +307,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRegressionEval3dMasking(Nd4jBackend backend) {
|
public void testRegressionEval3dMasking(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||||
|
@ -375,9 +367,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRegressionEval4dMasking(Nd4jBackend backend) {
|
public void testRegressionEval4dMasking(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
|
|
|
@ -44,9 +44,8 @@ public class TestLegacyJsonLoading extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvalLegacyFormat(Nd4jBackend backend) throws Exception {
|
public void testEvalLegacyFormat(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
File f = new ClassPathResource("regression_testing/eval_100b/evaluation.json").getFile();
|
File f = new ClassPathResource("regression_testing/eval_100b/evaluation.json").getFile();
|
||||||
|
|
|
@ -60,9 +60,8 @@ public class AveragingTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSingleDeviceAveraging1(Nd4jBackend backend) {
|
public void testSingleDeviceAveraging1(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0);
|
INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0);
|
||||||
INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0);
|
INDArray array2 = Nd4j.valueArrayOf(LENGTH, 2.0);
|
||||||
|
@ -109,9 +108,8 @@ public class AveragingTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(arrayMean, array16);
|
assertEquals(arrayMean, array16);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSingleDeviceAveraging2(Nd4jBackend backend) {
|
public void testSingleDeviceAveraging2(Nd4jBackend backend) {
|
||||||
INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH);
|
INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH);
|
||||||
List<INDArray> arrays = new ArrayList<>();
|
List<INDArray> arrays = new ArrayList<>();
|
||||||
|
@ -128,9 +126,8 @@ public class AveragingTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAccumulation1(Nd4jBackend backend) {
|
public void testAccumulation1(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.create(100).assign(1.0);
|
INDArray array1 = Nd4j.create(100).assign(1.0);
|
||||||
INDArray array2 = Nd4j.create(100).assign(2.0);
|
INDArray array2 = Nd4j.create(100).assign(2.0);
|
||||||
|
@ -143,9 +140,8 @@ public class AveragingTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAccumulation2(Nd4jBackend backend) {
|
public void testAccumulation2(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.create(100).assign(1.0);
|
INDArray array1 = Nd4j.create(100).assign(1.0);
|
||||||
INDArray array2 = Nd4j.create(100).assign(2.0);
|
INDArray array2 = Nd4j.create(100).assign(2.0);
|
||||||
|
@ -160,9 +156,8 @@ public class AveragingTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAccumulation3(Nd4jBackend backend) {
|
public void testAccumulation3(Nd4jBackend backend) {
|
||||||
// we want to ensure that cuda backend is able to launch this op on cpu
|
// we want to ensure that cuda backend is able to launch this op on cpu
|
||||||
Nd4j.getAffinityManager().allowCrossDeviceAccess(false);
|
Nd4j.getAffinityManager().allowCrossDeviceAccess(false);
|
||||||
|
|
|
@ -39,9 +39,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class DataTypeTest extends BaseNd4jTestWithBackends {
|
public class DataTypeTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDataTypes(Nd4jBackend backend) throws Exception {
|
public void testDataTypes(Nd4jBackend backend) throws Exception {
|
||||||
for (val type : DataType.values()) {
|
for (val type : DataType.values()) {
|
||||||
if (DataType.UTF8.equals(type) || DataType.UNKNOWN.equals(type) || DataType.COMPRESSED.equals(type))
|
if (DataType.UTF8.equals(type) || DataType.UNKNOWN.equals(type) || DataType.COMPRESSED.equals(type))
|
||||||
|
|
|
@ -41,9 +41,8 @@ public class InputValidationTests extends BaseNd4jTestWithBackends {
|
||||||
/////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////
|
||||||
///////////////////// Broadcast Tests ///////////////////////
|
///////////////////// Broadcast Tests ///////////////////////
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInvalidColVectorOp1(Nd4jBackend backend) {
|
public void testInvalidColVectorOp1(Nd4jBackend backend) {
|
||||||
INDArray first = Nd4j.create(10, 10);
|
INDArray first = Nd4j.create(10, 10);
|
||||||
INDArray col = Nd4j.create(5, 1);
|
INDArray col = Nd4j.create(5, 1);
|
||||||
|
@ -55,9 +54,8 @@ public class InputValidationTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInvalidColVectorOp2(Nd4jBackend backend) {
|
public void testInvalidColVectorOp2(Nd4jBackend backend) {
|
||||||
INDArray first = Nd4j.create(10, 10);
|
INDArray first = Nd4j.create(10, 10);
|
||||||
INDArray col = Nd4j.create(5, 1);
|
INDArray col = Nd4j.create(5, 1);
|
||||||
|
@ -69,9 +67,8 @@ public class InputValidationTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInvalidRowVectorOp1(Nd4jBackend backend) {
|
public void testInvalidRowVectorOp1(Nd4jBackend backend) {
|
||||||
INDArray first = Nd4j.create(10, 10);
|
INDArray first = Nd4j.create(10, 10);
|
||||||
INDArray row = Nd4j.create(1, 5);
|
INDArray row = Nd4j.create(1, 5);
|
||||||
|
@ -83,9 +80,8 @@ public class InputValidationTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInvalidRowVectorOp2(Nd4jBackend backend) {
|
public void testInvalidRowVectorOp2(Nd4jBackend backend) {
|
||||||
INDArray first = Nd4j.create(10, 10);
|
INDArray first = Nd4j.create(10, 10);
|
||||||
INDArray row = Nd4j.create(1, 5);
|
INDArray row = Nd4j.create(1, 5);
|
||||||
|
|
|
@ -51,9 +51,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class LoneTest extends BaseNd4jTestWithBackends {
|
public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSoftmaxStability(Nd4jBackend backend) {
|
public void testSoftmaxStability(Nd4jBackend backend) {
|
||||||
INDArray input = Nd4j.create(new double[]{-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose();
|
INDArray input = Nd4j.create(new double[]{-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04}).reshape(1, -1).transpose();
|
||||||
// System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo()));
|
// System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo()));
|
||||||
|
@ -67,9 +66,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testFlattenedView(Nd4jBackend backend) {
|
public void testFlattenedView(Nd4jBackend backend) {
|
||||||
int rows = 8;
|
int rows = 8;
|
||||||
int cols = 8;
|
int cols = 8;
|
||||||
|
@ -105,9 +103,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(fAssertion, Nd4j.toFlattened('f', first));
|
assertEquals(fAssertion, Nd4j.toFlattened('f', first));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIndexingColVec(Nd4jBackend backend) {
|
public void testIndexingColVec(Nd4jBackend backend) {
|
||||||
int elements = 5;
|
int elements = 5;
|
||||||
INDArray rowVector = Nd4j.linspace(1, elements, elements).reshape(1, elements);
|
INDArray rowVector = Nd4j.linspace(1, elements, elements).reshape(1, elements);
|
||||||
|
@ -126,9 +123,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void concatScalarVectorIssue(Nd4jBackend backend) {
|
public void concatScalarVectorIssue(Nd4jBackend backend) {
|
||||||
//A bug was found when the first array that concat sees is a scalar and the rest vectors + scalars
|
//A bug was found when the first array that concat sees is a scalar and the rest vectors + scalars
|
||||||
INDArray arr1 = Nd4j.create(1, 1);
|
INDArray arr1 = Nd4j.create(1, 1);
|
||||||
|
@ -138,9 +134,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
assertTrue(arr4.sumNumber().floatValue() <= Nd4j.EPS_THRESHOLD);
|
assertTrue(arr4.sumNumber().floatValue() <= Nd4j.EPS_THRESHOLD);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void reshapeTensorMmul(Nd4jBackend backend) {
|
public void reshapeTensorMmul(Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.linspace(1, 2, 12).reshape(2, 3, 2);
|
INDArray a = Nd4j.linspace(1, 2, 12).reshape(2, 3, 2);
|
||||||
INDArray b = Nd4j.linspace(3, 4, 4).reshape(2, 2);
|
INDArray b = Nd4j.linspace(3, 4, 4).reshape(2, 2);
|
||||||
|
@ -152,9 +147,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
INDArray c = Nd4j.tensorMmul(b, a, axes);
|
INDArray c = Nd4j.tensorMmul(b, a, axes);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void maskWhenMerge(Nd4jBackend backend) {
|
public void maskWhenMerge(Nd4jBackend backend) {
|
||||||
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
|
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
|
||||||
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
|
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
|
||||||
|
@ -169,9 +163,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRelu(Nd4jBackend backend) {
|
public void testRelu(Nd4jBackend backend) {
|
||||||
INDArray aA = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
|
INDArray aA = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
|
||||||
INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
|
INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
|
||||||
|
@ -197,9 +190,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(max - 1, currentArgMax);
|
assertEquals(max - 1, currentArgMax);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRPF(Nd4jBackend backend) {
|
public void testRPF(Nd4jBackend backend) {
|
||||||
val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3);
|
val array = Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12).reshape(2, 2, 3);
|
||||||
|
|
||||||
|
@ -212,9 +204,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
log.info("TAD:\n{}", tad);
|
log.info("TAD:\n{}", tad);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConcat3D_Vstack_C(Nd4jBackend backend) {
|
public void testConcat3D_Vstack_C(Nd4jBackend backend) {
|
||||||
val shape = new long[]{1, 1000, 20};
|
val shape = new long[]{1, 1000, 20};
|
||||||
|
|
||||||
|
@ -244,9 +235,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetRow1(Nd4jBackend backend) {
|
public void testGetRow1(Nd4jBackend backend) {
|
||||||
INDArray array = Nd4j.create(10000, 10000);
|
INDArray array = Nd4j.create(10000, 10000);
|
||||||
|
|
||||||
|
@ -285,9 +275,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void checkSliceofSlice(Nd4jBackend backend) {
|
public void checkSliceofSlice(Nd4jBackend backend) {
|
||||||
/*
|
/*
|
||||||
Issue 1: Slice of slice with c order and f order views are not equal
|
Issue 1: Slice of slice with c order and f order views are not equal
|
||||||
|
@ -327,9 +316,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void checkWithReshape(Nd4jBackend backend) {
|
public void checkWithReshape(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.create(1, 3);
|
INDArray arr = Nd4j.create(1, 3);
|
||||||
INDArray reshaped = arr.reshape('f', 3, 1);
|
INDArray reshaped = arr.reshape('f', 3, 1);
|
||||||
|
|
|
@ -38,9 +38,8 @@ public class MmulBug extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void simpleTest(Nd4jBackend backend) {
|
public void simpleTest(Nd4jBackend backend) {
|
||||||
INDArray m1 = Nd4j.create(new double[][] {{1.0}, {2.0}, {3.0}, {4.0}});
|
INDArray m1 = Nd4j.create(new double[][] {{1.0}, {2.0}, {3.0}, {4.0}});
|
||||||
|
|
||||||
|
|
|
@ -63,9 +63,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScalarOps(Nd4jBackend backend) {
|
public void testScalarOps(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
|
INDArray n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
|
||||||
assertEquals(27d, n.length(), 1e-1);
|
assertEquals(27d, n.length(), 1e-1);
|
||||||
|
@ -83,9 +82,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testColumnMmul(Nd4jBackend backend) {
|
public void testColumnMmul(Nd4jBackend backend) {
|
||||||
DataBuffer data = Nd4j.linspace(1, 10, 18, DataType.FLOAT).data();
|
DataBuffer data = Nd4j.linspace(1, 10, 18, DataType.FLOAT).data();
|
||||||
INDArray x2 = Nd4j.create(data, new long[] {2, 3, 3});
|
INDArray x2 = Nd4j.create(data, new long[] {2, 3, 3});
|
||||||
|
@ -116,9 +114,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRowVectorGemm(Nd4jBackend backend) {
|
public void testRowVectorGemm(Nd4jBackend backend) {
|
||||||
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1).castTo(DataType.DOUBLE);
|
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1).castTo(DataType.DOUBLE);
|
||||||
INDArray other = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4).castTo(DataType.DOUBLE);
|
INDArray other = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4).castTo(DataType.DOUBLE);
|
||||||
|
@ -129,18 +126,16 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRepmat(Nd4jBackend backend) {
|
public void testRepmat(Nd4jBackend backend) {
|
||||||
INDArray rowVector = Nd4j.create(1, 4);
|
INDArray rowVector = Nd4j.create(1, 4);
|
||||||
INDArray repmat = rowVector.repmat(4, 4);
|
INDArray repmat = rowVector.repmat(4, 4);
|
||||||
assertTrue(Arrays.equals(new long[] {4, 16}, repmat.shape()));
|
assertTrue(Arrays.equals(new long[] {4, 16}, repmat.shape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReadWrite() throws Exception {
|
public void testReadWrite() throws Exception {
|
||||||
INDArray write = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
INDArray write = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||||
|
@ -155,9 +150,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReadWriteDouble() throws Exception {
|
public void testReadWriteDouble() throws Exception {
|
||||||
INDArray write = Nd4j.linspace(1, 4, 4, DataType.FLOAT);
|
INDArray write = Nd4j.linspace(1, 4, 4, DataType.FLOAT);
|
||||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||||
|
@ -173,9 +167,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMultiThreading() throws Exception {
|
public void testMultiThreading() throws Exception {
|
||||||
ExecutorService ex = ExecutorServiceProvider.getExecutorService();
|
ExecutorService ex = ExecutorServiceProvider.getExecutorService();
|
||||||
|
|
||||||
|
@ -195,9 +188,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBroadcastingGenerated(Nd4jBackend backend) {
|
public void testBroadcastingGenerated(Nd4jBackend backend) {
|
||||||
int[][] broadcastShape = NDArrayCreationUtil.getRandomBroadCastShape(7, 6, 10);
|
int[][] broadcastShape = NDArrayCreationUtil.getRandomBroadCastShape(7, 6, 10);
|
||||||
List<List<Pair<INDArray, String>>> broadCastList = new ArrayList<>(broadcastShape.length);
|
List<List<Pair<INDArray, String>>> broadCastList = new ArrayList<>(broadcastShape.length);
|
||||||
|
@ -222,9 +214,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBroadCasting(Nd4jBackend backend) {
|
public void testBroadCasting(Nd4jBackend backend) {
|
||||||
INDArray first = Nd4j.arange(0, 3).reshape(3, 1).castTo(DataType.DOUBLE);
|
INDArray first = Nd4j.arange(0, 3).reshape(3, 1).castTo(DataType.DOUBLE);
|
||||||
INDArray ret = first.broadcast(3, 4);
|
INDArray ret = first.broadcast(3, 4);
|
||||||
|
@ -237,18 +228,16 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOneTensor(Nd4jBackend backend) {
|
public void testOneTensor(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.ones(1, 1, 1, 1, 1, 1, 1);
|
INDArray arr = Nd4j.ones(1, 1, 1, 1, 1, 1, 1);
|
||||||
INDArray matrixToBroadcast = Nd4j.ones(1, 1);
|
INDArray matrixToBroadcast = Nd4j.ones(1, 1);
|
||||||
assertEquals(matrixToBroadcast.broadcast(arr.shape()), arr);
|
assertEquals(matrixToBroadcast.broadcast(arr.shape()), arr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSortWithIndicesDescending(Nd4jBackend backend) {
|
public void testSortWithIndicesDescending(Nd4jBackend backend) {
|
||||||
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
//indices,data
|
//indices,data
|
||||||
|
@ -259,9 +248,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(shouldIndex, sorted[0],getFailureMessage());
|
assertEquals(shouldIndex, sorted[0],getFailureMessage());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSortDeadlock(Nd4jBackend backend) {
|
public void testSortDeadlock(Nd4jBackend backend) {
|
||||||
val toSort = Nd4j.linspace(DataType.DOUBLE, 1, 32*768, 1).reshape(32, 768);
|
val toSort = Nd4j.linspace(DataType.DOUBLE, 1, 32*768, 1).reshape(32, 768);
|
||||||
|
|
||||||
|
@ -269,9 +257,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSortWithIndices(Nd4jBackend backend) {
|
public void testSortWithIndices(Nd4jBackend backend) {
|
||||||
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray toSort = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
//indices,data
|
//indices,data
|
||||||
|
@ -282,18 +269,16 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(shouldIndex, sorted[0],getFailureMessage());
|
assertEquals(shouldIndex, sorted[0],getFailureMessage());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNd4jSortScalar(Nd4jBackend backend) {
|
public void testNd4jSortScalar(Nd4jBackend backend) {
|
||||||
INDArray linspace = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(1, -1);
|
INDArray linspace = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(1, -1);
|
||||||
INDArray sorted = Nd4j.sort(linspace, 1, false);
|
INDArray sorted = Nd4j.sort(linspace, 1, false);
|
||||||
// System.out.println(sorted);
|
// System.out.println(sorted);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSwapAxesFortranOrder(Nd4jBackend backend) {
|
public void testSwapAxesFortranOrder(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}).castTo(DataType.DOUBLE);
|
INDArray n = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}).castTo(DataType.DOUBLE);
|
||||||
for (int i = 0; i < n.slices(); i++) {
|
for (int i = 0; i < n.slices(); i++) {
|
||||||
|
@ -312,9 +297,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDimShuffle(Nd4jBackend backend) {
|
public void testDimShuffle(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
INDArray twoOneTwo = n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false});
|
INDArray twoOneTwo = n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false});
|
||||||
|
@ -325,9 +309,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetVsGetScalar(Nd4jBackend backend) {
|
public void testGetVsGetScalar(Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
float element = a.getFloat(0, 1);
|
float element = a.getFloat(0, 1);
|
||||||
|
@ -340,9 +323,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDivide(Nd4jBackend backend) {
|
public void testDivide(Nd4jBackend backend) {
|
||||||
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
|
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
|
||||||
INDArray div = two.div(two);
|
INDArray div = two.div(two);
|
||||||
|
@ -356,9 +338,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSigmoid(Nd4jBackend backend) {
|
public void testSigmoid(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
|
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
|
||||||
|
@ -367,9 +348,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNeg(Nd4jBackend backend) {
|
public void testNeg(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
||||||
|
@ -379,9 +359,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCosineSim(Nd4jBackend backend) {
|
public void testCosineSim(Nd4jBackend backend) {
|
||||||
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
|
@ -396,9 +375,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testExp(Nd4jBackend backend) {
|
public void testExp(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f});
|
INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f});
|
||||||
|
@ -408,9 +386,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScalar(Nd4jBackend backend) {
|
public void testScalar(Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.scalar(1.0f);
|
INDArray a = Nd4j.scalar(1.0f);
|
||||||
assertEquals(true, a.isScalar());
|
assertEquals(true, a.isScalar());
|
||||||
|
@ -422,9 +399,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testWrap(Nd4jBackend backend) {
|
public void testWrap(Nd4jBackend backend) {
|
||||||
int[] shape = {2, 4};
|
int[] shape = {2, 4};
|
||||||
INDArray d = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(shape[0], shape[1]);
|
INDArray d = Nd4j.linspace(1, 8, 8, DataType.DOUBLE).reshape(shape[0], shape[1]);
|
||||||
|
@ -449,9 +425,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(row22.columns(), 2);
|
assertEquals(row22.columns(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetRowFortran(Nd4jBackend backend) {
|
public void testGetRowFortran(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.FLOAT).data(), new long[] {2, 2});
|
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.FLOAT).data(), new long[] {2, 2});
|
||||||
INDArray column = Nd4j.create(new float[] {1, 3});
|
INDArray column = Nd4j.create(new float[] {1, 3});
|
||||||
|
@ -464,9 +439,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetColumnFortran(Nd4jBackend backend) {
|
public void testGetColumnFortran(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2});
|
INDArray n = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2});
|
||||||
INDArray column = Nd4j.create(new double[] {1, 2});
|
INDArray column = Nd4j.create(new double[] {1, 2});
|
||||||
|
@ -480,9 +454,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetColumns(Nd4jBackend backend) {
|
public void testGetColumns(Nd4jBackend backend) {
|
||||||
INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3).castTo(DataType.DOUBLE);
|
INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3).castTo(DataType.DOUBLE);
|
||||||
// log.info("Original: {}", matrix);
|
// log.info("Original: {}", matrix);
|
||||||
|
@ -496,9 +469,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testVectorInit(Nd4jBackend backend) {
|
public void testVectorInit(Nd4jBackend backend) {
|
||||||
DataBuffer data = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data();
|
DataBuffer data = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data();
|
||||||
INDArray arr = Nd4j.create(data, new long[] {1, 4});
|
INDArray arr = Nd4j.create(data, new long[] {1, 4});
|
||||||
|
@ -511,9 +483,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAssignOffset(Nd4jBackend backend) {
|
public void testAssignOffset(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.ones(5, 5);
|
INDArray arr = Nd4j.ones(5, 5);
|
||||||
INDArray row = arr.slice(1);
|
INDArray row = arr.slice(1);
|
||||||
|
@ -521,9 +492,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(Nd4j.ones(5), row);
|
assertEquals(Nd4j.ones(5), row);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testColumns(Nd4jBackend backend) {
|
public void testColumns(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE);
|
INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE);
|
||||||
INDArray column = Nd4j.create(new double[] {1, 2, 3});
|
INDArray column = Nd4j.create(new double[] {1, 2, 3});
|
||||||
|
@ -561,9 +531,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPutRow(Nd4jBackend backend) {
|
public void testPutRow(Nd4jBackend backend) {
|
||||||
INDArray d = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray d = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
INDArray n = d.dup();
|
INDArray n = d.dup();
|
||||||
|
@ -622,9 +591,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInplaceTranspose(Nd4jBackend backend) {
|
public void testInplaceTranspose(Nd4jBackend backend) {
|
||||||
INDArray test = Nd4j.rand(3, 4);
|
INDArray test = Nd4j.rand(3, 4);
|
||||||
INDArray orig = test.dup();
|
INDArray orig = test.dup();
|
||||||
|
@ -639,9 +607,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMmulF(Nd4jBackend backend) {
|
public void testMmulF(Nd4jBackend backend) {
|
||||||
|
|
||||||
DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data();
|
DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data();
|
||||||
|
@ -659,9 +626,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRowsColumns(Nd4jBackend backend) {
|
public void testRowsColumns(Nd4jBackend backend) {
|
||||||
DataBuffer data = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).data();
|
DataBuffer data = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).data();
|
||||||
INDArray rows = Nd4j.create(data, new long[] {2, 3});
|
INDArray rows = Nd4j.create(data, new long[] {2, 3});
|
||||||
|
@ -677,9 +643,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTranspose(Nd4jBackend backend) {
|
public void testTranspose(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(Nd4j.ones(100).castTo(DataType.DOUBLE).data(), new long[] {5, 5, 4});
|
INDArray n = Nd4j.create(Nd4j.ones(100).castTo(DataType.DOUBLE).data(), new long[] {5, 5, 4});
|
||||||
INDArray transpose = n.transpose();
|
INDArray transpose = n.transpose();
|
||||||
|
@ -707,9 +672,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAddMatrix(Nd4jBackend backend) {
|
public void testAddMatrix(Nd4jBackend backend) {
|
||||||
INDArray five = Nd4j.ones(5);
|
INDArray five = Nd4j.ones(5);
|
||||||
five.addi(five.dup());
|
five.addi(five.dup());
|
||||||
|
@ -720,9 +684,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMMul(Nd4jBackend backend) {
|
public void testMMul(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
|
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
|
||||||
|
|
||||||
|
@ -733,9 +696,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPutSlice(Nd4jBackend backend) {
|
public void testPutSlice(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3);
|
INDArray n = Nd4j.linspace(1, 27, 27, DataType.DOUBLE).reshape(3, 3, 3);
|
||||||
INDArray newSlice = Nd4j.create(DataType.DOUBLE, 3, 3);
|
INDArray newSlice = Nd4j.create(DataType.DOUBLE, 3, 3);
|
||||||
|
@ -746,9 +708,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
|
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
|
||||||
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
|
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
|
||||||
linear.putScalar(new long[] {0, 1}, 1);
|
linear.putScalar(new long[] {0, 1}, 1);
|
||||||
|
@ -757,9 +718,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDim1(Nd4jBackend backend) {
|
public void testDim1(Nd4jBackend backend) {
|
||||||
INDArray sum = Nd4j.linspace(1, 2, 2, DataType.DOUBLE).reshape(2, 1);
|
INDArray sum = Nd4j.linspace(1, 2, 2, DataType.DOUBLE).reshape(2, 1);
|
||||||
INDArray same = sum.dup();
|
INDArray same = sum.dup();
|
||||||
|
@ -767,9 +727,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEps(Nd4jBackend backend) {
|
public void testEps(Nd4jBackend backend) {
|
||||||
val ones = Nd4j.ones(5);
|
val ones = Nd4j.ones(5);
|
||||||
val res = Nd4j.createUninitialized(DataType.BOOL, 5);
|
val res = Nd4j.createUninitialized(DataType.BOOL, 5);
|
||||||
|
@ -777,9 +736,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLogDouble(Nd4jBackend backend) {
|
public void testLogDouble(Nd4jBackend backend) {
|
||||||
INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).castTo(DataType.DOUBLE);
|
INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).castTo(DataType.DOUBLE);
|
||||||
INDArray log = Transforms.log(linspace);
|
INDArray log = Transforms.log(linspace);
|
||||||
|
@ -787,36 +745,32 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(assertion, log);
|
assertEquals(assertion, log);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testVectorSum(Nd4jBackend backend) {
|
public void testVectorSum(Nd4jBackend backend) {
|
||||||
INDArray lin = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
INDArray lin = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||||
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
|
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testVectorSum2(Nd4jBackend backend) {
|
public void testVectorSum2(Nd4jBackend backend) {
|
||||||
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
|
assertEquals(10.0, lin.sumNumber().doubleValue(), 1e-1);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testVectorSum3(Nd4jBackend backend) {
|
public void testVectorSum3(Nd4jBackend backend) {
|
||||||
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray lin = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
INDArray lin2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray lin2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
assertEquals(lin, lin2);
|
assertEquals(lin, lin2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSmallSum(Nd4jBackend backend) {
|
public void testSmallSum(Nd4jBackend backend) {
|
||||||
INDArray base = Nd4j.create(new double[] {5.843333333333335, 3.0540000000000007});
|
INDArray base = Nd4j.create(new double[] {5.843333333333335, 3.0540000000000007});
|
||||||
base.addi(1e-12);
|
base.addi(1e-12);
|
||||||
|
@ -827,9 +781,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPermute(Nd4jBackend backend) {
|
public void testPermute(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4});
|
INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4});
|
||||||
INDArray transpose = n.transpose();
|
INDArray transpose = n.transpose();
|
||||||
|
@ -858,9 +811,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAppendBias(Nd4jBackend backend) {
|
public void testAppendBias(Nd4jBackend backend) {
|
||||||
INDArray rand = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(1, -1).transpose();
|
INDArray rand = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape(1, -1).transpose();
|
||||||
INDArray test = Nd4j.appendBias(rand);
|
INDArray test = Nd4j.appendBias(rand);
|
||||||
|
@ -868,9 +820,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(assertion, test);
|
assertEquals(assertion, test);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRand(Nd4jBackend backend) {
|
public void testRand(Nd4jBackend backend) {
|
||||||
INDArray rand = Nd4j.randn(5, 5);
|
INDArray rand = Nd4j.randn(5, 5);
|
||||||
Nd4j.getDistributions().createUniform(0.4, 4).sample(5);
|
Nd4j.getDistributions().createUniform(0.4, 4).sample(5);
|
||||||
|
@ -882,9 +833,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIdentity(Nd4jBackend backend) {
|
public void testIdentity(Nd4jBackend backend) {
|
||||||
INDArray eye = Nd4j.eye(5);
|
INDArray eye = Nd4j.eye(5);
|
||||||
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
|
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
|
||||||
|
@ -895,9 +845,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testColumnVectorOpsFortran(Nd4jBackend backend) {
|
public void testColumnVectorOpsFortran(Nd4jBackend backend) {
|
||||||
INDArray twoByTwo = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
|
INDArray twoByTwo = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
|
||||||
INDArray toAdd = Nd4j.create(new float[] {1, 2}, new long[] {2, 1});
|
INDArray toAdd = Nd4j.create(new float[] {1, 2}, new long[] {2, 1});
|
||||||
|
@ -908,9 +857,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRSubi(Nd4jBackend backend) {
|
public void testRSubi(Nd4jBackend backend) {
|
||||||
INDArray n2 = Nd4j.ones(2);
|
INDArray n2 = Nd4j.ones(2);
|
||||||
INDArray n2Assertion = Nd4j.zeros(2);
|
INDArray n2Assertion = Nd4j.zeros(2);
|
||||||
|
@ -920,9 +868,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAssign(Nd4jBackend backend) {
|
public void testAssign(Nd4jBackend backend) {
|
||||||
INDArray vector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
|
INDArray vector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
|
||||||
vector.assign(1);
|
vector.assign(1);
|
||||||
|
@ -939,9 +886,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(tensor, ones);
|
assertEquals(tensor, ones);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAddScalar(Nd4jBackend backend) {
|
public void testAddScalar(Nd4jBackend backend) {
|
||||||
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
|
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
|
||||||
INDArray rdiv = div.add(1);
|
INDArray rdiv = div.add(1);
|
||||||
|
@ -949,9 +895,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(answer, rdiv);
|
assertEquals(answer, rdiv);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRdivScalar(Nd4jBackend backend) {
|
public void testRdivScalar(Nd4jBackend backend) {
|
||||||
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
|
INDArray div = Nd4j.valueArrayOf(new long[] {1, 4}, 4.0);
|
||||||
INDArray rdiv = div.rdiv(1);
|
INDArray rdiv = div.rdiv(1);
|
||||||
|
@ -959,9 +904,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(rdiv, answer);
|
assertEquals(rdiv, answer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRDivi(Nd4jBackend backend) {
|
public void testRDivi(Nd4jBackend backend) {
|
||||||
INDArray n2 = Nd4j.valueArrayOf(new long[] {1, 2}, 4.0);
|
INDArray n2 = Nd4j.valueArrayOf(new long[] {1, 2}, 4.0);
|
||||||
INDArray n2Assertion = Nd4j.valueArrayOf(new long[] {1, 2}, 0.5);
|
INDArray n2Assertion = Nd4j.valueArrayOf(new long[] {1, 2}, 0.5);
|
||||||
|
@ -971,9 +915,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNumVectorsAlongDimension(Nd4jBackend backend) {
|
public void testNumVectorsAlongDimension(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2);
|
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2);
|
||||||
assertEquals(12, arr.vectorsAlongDimension(2));
|
assertEquals(12, arr.vectorsAlongDimension(2));
|
||||||
|
@ -981,9 +924,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBroadCast(Nd4jBackend backend) {
|
public void testBroadCast(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
INDArray n = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||||
INDArray broadCasted = n.broadcast(5, 4);
|
INDArray broadCasted = n.broadcast(5, 4);
|
||||||
|
@ -1005,9 +947,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
assertTrue(Arrays.equals(new long[] {1, 2, 36, 36}, broadCasted3.shape()));
|
assertTrue(Arrays.equals(new long[] {1, 2, 36, 36}, broadCasted3.shape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMatrix(Nd4jBackend backend) {
|
public void testMatrix(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
INDArray arr = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
||||||
INDArray brr = Nd4j.create(new double[] {5, 6}, new long[] {2});
|
INDArray brr = Nd4j.create(new double[] {5, 6}, new long[] {2});
|
||||||
|
@ -1017,9 +958,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPutRowGetRowOrdering(Nd4jBackend backend) {
|
public void testPutRowGetRowOrdering(Nd4jBackend backend) {
|
||||||
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
INDArray put = Nd4j.create(new double[] {5, 6});
|
INDArray put = Nd4j.create(new double[] {5, 6});
|
||||||
|
@ -1041,9 +981,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSumWithRow1(Nd4jBackend backend) {
|
public void testSumWithRow1(Nd4jBackend backend) {
|
||||||
//Works:
|
//Works:
|
||||||
INDArray array2d = Nd4j.ones(1, 10);
|
INDArray array2d = Nd4j.ones(1, 10);
|
||||||
|
@ -1074,9 +1013,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
array5d.sum(4); //java.lang.IllegalArgumentException: Illegal index 10000 derived from 9 with offset of 1000 and stride of 1000
|
array5d.sum(4); //java.lang.IllegalArgumentException: Illegal index 10000 derived from 9 with offset of 1000 and stride of 1000
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSumWithRow2(Nd4jBackend backend) {
|
public void testSumWithRow2(Nd4jBackend backend) {
|
||||||
//All sums in this method execute without exceptions.
|
//All sums in this method execute without exceptions.
|
||||||
INDArray array3d = Nd4j.ones(2, 10, 10);
|
INDArray array3d = Nd4j.ones(2, 10, 10);
|
||||||
|
@ -1099,9 +1037,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPutRowFortran(Nd4jBackend backend) {
|
public void testPutRowFortran(Nd4jBackend backend) {
|
||||||
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2).castTo(DataType.DOUBLE);
|
INDArray row1 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2).castTo(DataType.DOUBLE);
|
||||||
INDArray put = Nd4j.create(new double[] {5, 6});
|
INDArray put = Nd4j.create(new double[] {5, 6});
|
||||||
|
@ -1114,9 +1051,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testElementWiseOps(Nd4jBackend backend) {
|
public void testElementWiseOps(Nd4jBackend backend) {
|
||||||
INDArray n1 = Nd4j.scalar(1);
|
INDArray n1 = Nd4j.scalar(1);
|
||||||
INDArray n2 = Nd4j.scalar(2);
|
INDArray n2 = Nd4j.scalar(2);
|
||||||
|
@ -1139,9 +1075,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRollAxis(Nd4jBackend backend) {
|
public void testRollAxis(Nd4jBackend backend) {
|
||||||
INDArray toRoll = Nd4j.ones(3, 4, 5, 6);
|
INDArray toRoll = Nd4j.ones(3, 4, 5, 6);
|
||||||
assertArrayEquals(new long[] {3, 6, 4, 5}, Nd4j.rollAxis(toRoll, 3, 1).shape());
|
assertArrayEquals(new long[] {3, 6, 4, 5}, Nd4j.rollAxis(toRoll, 3, 1).shape());
|
||||||
|
@ -1163,9 +1098,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNegativeShape(Nd4jBackend backend) {
|
public void testNegativeShape(Nd4jBackend backend) {
|
||||||
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||||
INDArray reshaped = linspace.reshape(-1, 2);
|
INDArray reshaped = linspace.reshape(-1, 2);
|
||||||
|
@ -1177,9 +1111,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetColumnGetRow(Nd4jBackend backend) {
|
public void testGetColumnGetRow(Nd4jBackend backend) {
|
||||||
INDArray row = Nd4j.ones(1, 5);
|
INDArray row = Nd4j.ones(1, 5);
|
||||||
for (int i = 0; i < 5; i++) {
|
for (int i = 0; i < 5; i++) {
|
||||||
|
@ -1194,9 +1127,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDupAndDupWithOrder(Nd4jBackend backend) {
|
public void testDupAndDupWithOrder(Nd4jBackend backend) {
|
||||||
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
|
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
|
||||||
int count = 0;
|
int count = 0;
|
||||||
|
@ -1218,9 +1150,8 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testToOffsetZeroCopy(Nd4jBackend backend) {
|
public void testToOffsetZeroCopy(Nd4jBackend backend) {
|
||||||
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
|
List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123, DataType.DOUBLE);
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -69,9 +69,8 @@ public class Nd4jTestsComparisonC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
|
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
|
||||||
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||||
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
|
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
|
||||||
|
|
|
@ -71,9 +71,8 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
|
||||||
return 'f';
|
return 'f';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCrash(Nd4jBackend backend) {
|
public void testCrash(Nd4jBackend backend) {
|
||||||
INDArray array3d = Nd4j.ones(1, 10, 10);
|
INDArray array3d = Nd4j.ones(1, 10, 10);
|
||||||
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 0);
|
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 0);
|
||||||
|
@ -83,9 +82,8 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
|
||||||
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array4d, 0);
|
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array4d, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMmulWithOpsCommonsMath(Nd4jBackend backend) {
|
public void testMmulWithOpsCommonsMath(Nd4jBackend backend) {
|
||||||
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||||
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
|
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
|
||||||
|
@ -100,9 +98,8 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
|
public void testGemmWithOpsCommonsMath(Nd4jBackend backend) {
|
||||||
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||||
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
|
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
|
||||||
|
@ -158,9 +155,8 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemvApacheCommons(Nd4jBackend backend) {
|
public void testGemvApacheCommons(Nd4jBackend backend) {
|
||||||
|
|
||||||
int[] rowsArr = new int[] {4, 4, 4, 8, 8, 8};
|
int[] rowsArr = new int[] {4, 4, 4, 8, 8, 8};
|
||||||
|
@ -215,9 +211,8 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAddSubtractWithOpsCommonsMath(Nd4jBackend backend) {
|
public void testAddSubtractWithOpsCommonsMath(Nd4jBackend backend) {
|
||||||
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||||
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||||
|
@ -235,9 +230,8 @@ public class Nd4jTestsComparisonFortran extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMulDivOnCheckUtilMatrices(Nd4jBackend backend) {
|
public void testMulDivOnCheckUtilMatrices(Nd4jBackend backend) {
|
||||||
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||||
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||||||
|
|
|
@ -42,9 +42,8 @@ public class Nd4jTestsF extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
DataType initialType = Nd4j.dataType();
|
DataType initialType = Nd4j.dataType();
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConcat3D_Vstack_F(Nd4jBackend backend) {
|
public void testConcat3D_Vstack_F(Nd4jBackend backend) {
|
||||||
//Nd4j.getExecutioner().enableVerboseMode(true);
|
//Nd4j.getExecutioner().enableVerboseMode(true);
|
||||||
//Nd4j.getExecutioner().enableDebugMode(true);
|
//Nd4j.getExecutioner().enableDebugMode(true);
|
||||||
|
@ -76,9 +75,8 @@ public class Nd4jTestsF extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSlice_1(Nd4jBackend backend) {
|
public void testSlice_1(Nd4jBackend backend) {
|
||||||
val arr = Nd4j.linspace(1,4, 4, DataType.DOUBLE).reshape(2, 2, 1);
|
val arr = Nd4j.linspace(1,4, 4, DataType.DOUBLE).reshape(2, 2, 1);
|
||||||
val exp0 = Nd4j.create(new double[]{1, 3}, new int[] {2, 1});
|
val exp0 = Nd4j.create(new double[]{1, 3}, new int[] {2, 1});
|
||||||
|
|
|
@ -32,15 +32,14 @@ import org.nd4j.common.util.ArrayUtil;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
|
||||||
public class ShufflesTests extends BaseNd4jTestWithBackends {
|
public class ShufflesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSimpleShuffle1(Nd4jBackend backend) {
|
public void testSimpleShuffle1(Nd4jBackend backend) {
|
||||||
INDArray array = Nd4j.zeros(10, 10);
|
INDArray array = Nd4j.zeros(10, 10);
|
||||||
for (int x = 0; x < 10; x++) {
|
for (int x = 0; x < 10; x++) {
|
||||||
|
@ -62,9 +61,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
|
||||||
assertTrue(scanner.compareRow(array));
|
assertTrue(scanner.compareRow(array));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSimpleShuffle2(Nd4jBackend backend) {
|
public void testSimpleShuffle2(Nd4jBackend backend) {
|
||||||
INDArray array = Nd4j.zeros(10, 10);
|
INDArray array = Nd4j.zeros(10, 10);
|
||||||
for (int x = 0; x < 10; x++) {
|
for (int x = 0; x < 10; x++) {
|
||||||
|
@ -79,9 +77,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
|
||||||
assertTrue(scanner.compareColumn(array));
|
assertTrue(scanner.compareColumn(array));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSimpleShuffle3(Nd4jBackend backend) {
|
public void testSimpleShuffle3(Nd4jBackend backend) {
|
||||||
INDArray array = Nd4j.zeros(11, 10);
|
INDArray array = Nd4j.zeros(11, 10);
|
||||||
for (int x = 0; x < 11; x++) {
|
for (int x = 0; x < 11; x++) {
|
||||||
|
@ -97,9 +94,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
|
||||||
assertTrue(scanner.compareRow(array));
|
assertTrue(scanner.compareRow(array));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSymmetricShuffle1(Nd4jBackend backend) {
|
public void testSymmetricShuffle1(Nd4jBackend backend) {
|
||||||
INDArray features = Nd4j.zeros(10, 10);
|
INDArray features = Nd4j.zeros(10, 10);
|
||||||
INDArray labels = Nd4j.zeros(10, 3);
|
INDArray labels = Nd4j.zeros(10, 3);
|
||||||
|
@ -137,9 +133,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSymmetricShuffle2(Nd4jBackend backend) {
|
public void testSymmetricShuffle2(Nd4jBackend backend) {
|
||||||
INDArray features = Nd4j.zeros(10, 10, 20);
|
INDArray features = Nd4j.zeros(10, 10, 20);
|
||||||
INDArray labels = Nd4j.zeros(10, 10, 3);
|
INDArray labels = Nd4j.zeros(10, 10, 3);
|
||||||
|
@ -177,9 +172,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSymmetricShuffle3(Nd4jBackend backend) {
|
public void testSymmetricShuffle3(Nd4jBackend backend) {
|
||||||
INDArray features = Nd4j.zeros(10, 10, 20);
|
INDArray features = Nd4j.zeros(10, 10, 20);
|
||||||
INDArray featuresMask = Nd4j.zeros(10, 20);
|
INDArray featuresMask = Nd4j.zeros(10, 20);
|
||||||
|
@ -244,9 +238,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
|
||||||
* There's SMALL chance this test will randomly fail, since spread isn't too big
|
* There's SMALL chance this test will randomly fail, since spread isn't too big
|
||||||
* @throws Exception
|
* @throws Exception
|
||||||
*/
|
*/
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testHalfVectors1(Nd4jBackend backend) {
|
public void testHalfVectors1(Nd4jBackend backend) {
|
||||||
int[] array1 = ArrayUtil.buildHalfVector(new Random(12), 20);
|
int[] array1 = ArrayUtil.buildHalfVector(new Random(12), 20);
|
||||||
int[] array2 = ArrayUtil.buildHalfVector(new Random(75), 20);
|
int[] array2 = ArrayUtil.buildHalfVector(new Random(75), 20);
|
||||||
|
@ -267,9 +260,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInterleavedVector1(Nd4jBackend backend) {
|
public void testInterleavedVector1(Nd4jBackend backend) {
|
||||||
int[] array1 = ArrayUtil.buildInterleavedVector(new Random(12), 20);
|
int[] array1 = ArrayUtil.buildInterleavedVector(new Random(12), 20);
|
||||||
int[] array2 = ArrayUtil.buildInterleavedVector(new Random(75), 20);
|
int[] array2 = ArrayUtil.buildInterleavedVector(new Random(75), 20);
|
||||||
|
@ -290,9 +282,8 @@ public class ShufflesTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInterleavedVector3(Nd4jBackend backend) {
|
public void testInterleavedVector3(Nd4jBackend backend) {
|
||||||
for (int e = 0; e < 1000; e++) {
|
for (int e = 0; e < 1000; e++) {
|
||||||
int length = e + 256; //RandomUtils.nextInt(121, 2073);
|
int length = e + 256; //RandomUtils.nextInt(121, 2073);
|
||||||
|
|
|
@ -54,9 +54,8 @@ public class TestEigen extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
// test of functions added by Luke Czapla
|
// test of functions added by Luke Czapla
|
||||||
// Compares solution of A x = L x to solution to A x = L B x when it is simple
|
// Compares solution of A x = L x to solution to A x = L B x when it is simple
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void test2Syev(Nd4jBackend backend) {
|
public void test2Syev(Nd4jBackend backend) {
|
||||||
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
||||||
Nd4j.setDefaultDataTypes(dt, dt);
|
Nd4j.setDefaultDataTypes(dt, dt);
|
||||||
|
@ -75,9 +74,8 @@ public class TestEigen extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSyev(Nd4jBackend backend) {
|
public void testSyev(Nd4jBackend backend) {
|
||||||
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
||||||
//log.info("Datatype: {}", dt);
|
//log.info("Datatype: {}", dt);
|
||||||
|
|
|
@ -37,9 +37,8 @@ import org.nd4j.common.util.ArrayUtil;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class ToStringTest extends BaseNd4jTestWithBackends {
|
public class ToStringTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testToString(Nd4jBackend backend) throws Exception {
|
public void testToString(Nd4jBackend backend) throws Exception {
|
||||||
assertEquals("[ 1, 2, 3]",
|
assertEquals("[ 1, 2, 3]",
|
||||||
Nd4j.createFromArray(1, 2, 3).toString());
|
Nd4j.createFromArray(1, 2, 3).toString());
|
||||||
|
@ -57,9 +56,8 @@ public class ToStringTest extends BaseNd4jTestWithBackends {
|
||||||
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(6, true, 1));
|
Nd4j.createFromArray(1, 2, 3, 4, 5, 6, 7, 8).toString(6, true, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testToStringScalars(){
|
public void testToStringScalars(){
|
||||||
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
|
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
|
||||||
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};
|
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};
|
||||||
|
|
|
@ -53,8 +53,9 @@ import java.util.Arrays;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertTrue;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
|
||||||
public class TestActivation extends BaseNd4jTestWithBackends {
|
public class TestActivation extends BaseNd4jTestWithBackends {
|
||||||
|
@ -76,9 +77,8 @@ public class TestActivation extends BaseNd4jTestWithBackends {
|
||||||
mapper.enable(SerializationFeature.INDENT_OUTPUT);
|
mapper.enable(SerializationFeature.INDENT_OUTPUT);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRelu(Nd4jBackend backend){
|
public void testRelu(Nd4jBackend backend){
|
||||||
|
|
||||||
Double[] max = {null, 6.0, 2.5, 5.0};
|
Double[] max = {null, 6.0, 2.5, 5.0};
|
||||||
|
@ -130,9 +130,8 @@ public class TestActivation extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testJson(Nd4jBackend backend) throws Exception {
|
public void testJson(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
IActivation[] activations = new IActivation[] {new ActivationCube(), new ActivationELU(0.25),
|
IActivation[] activations = new IActivation[] {new ActivationCube(), new ActivationELU(0.25),
|
||||||
|
@ -179,7 +178,7 @@ public class TestActivation extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
for (String s : expFields) {
|
for (String s : expFields) {
|
||||||
msg = "Expected field \"" + s + "\", was not found in " + activations[i].toString();
|
msg = "Expected field \"" + s + "\", was not found in " + activations[i].toString();
|
||||||
assertTrue(msg, actualFieldsByName.contains(s));
|
assertTrue(actualFieldsByName.contains(s),msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Test conversion from JSON:
|
//Test conversion from JSON:
|
||||||
|
|
|
@ -30,9 +30,8 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
public class TestBackend extends BaseNd4jTestWithBackends {
|
public class TestBackend extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testBuildInfo(Nd4jBackend backend){
|
public void testBuildInfo(Nd4jBackend backend){
|
||||||
System.out.println("Backend build info: " + backend.buildInfo());
|
System.out.println("Backend build info: " + backend.buildInfo());
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,9 +37,8 @@ public class TestEnvironment extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testEnvironment(Nd4jBackend backend){
|
public void testEnvironment(Nd4jBackend backend){
|
||||||
Environment e = Nd4j.getEnvironment();
|
Environment e = Nd4j.getEnvironment();
|
||||||
System.out.println("BLAS version: " + e.blasMajorVersion() + "." + e.blasMinorVersion() + "." + e.blasPatchVersion());
|
System.out.println("BLAS version: " + e.blasMajorVersion() + "." + e.blasMinorVersion() + "." + e.blasPatchVersion());
|
||||||
|
|
|
@ -44,9 +44,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBufferCreation(Nd4jBackend backend) {
|
public void testBufferCreation(Nd4jBackend backend) {
|
||||||
DataBuffer dataBuffer = Nd4j.createBuffer(new float[] {1, 2});
|
DataBuffer dataBuffer = Nd4j.createBuffer(new float[] {1, 2});
|
||||||
Pointer pointer = dataBuffer.pointer();
|
Pointer pointer = dataBuffer.pointer();
|
||||||
|
@ -68,7 +67,7 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||||
@Test
|
@Test
|
||||||
@Disabled
|
@Disabled
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCreateNpy() throws Exception {
|
public void testCreateNpy() throws Exception {
|
||||||
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile());
|
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile());
|
||||||
assertEquals(2, arrCreate.size(0));
|
assertEquals(2, arrCreate.size(0));
|
||||||
|
@ -83,7 +82,7 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||||
@Test
|
@Test
|
||||||
@Disabled
|
@Disabled
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCreateNpz(Nd4jBackend backend) throws Exception {
|
public void testCreateNpz(Nd4jBackend backend) throws Exception {
|
||||||
Map<String, INDArray> map = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile());
|
Map<String, INDArray> map = Nd4j.createFromNpzFile(new ClassPathResource("nd4j-tests/test.npz").getFile());
|
||||||
assertEquals(true, map.containsKey("x"));
|
assertEquals(true, map.containsKey("x"));
|
||||||
|
|
|
@ -35,9 +35,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||||
public class TestNDArrayCreationUtil extends BaseNd4jTestWithBackends {
|
public class TestNDArrayCreationUtil extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testShapes() {
|
public void testShapes() {
|
||||||
|
|
||||||
long[] shape2d = {2, 3};
|
long[] shape2d = {2, 3};
|
||||||
|
|
|
@ -32,9 +32,8 @@ import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
public class TestNamespaces extends BaseNd4jTestWithBackends {
|
public class TestNamespaces extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBitwiseSimple(Nd4jBackend backend){
|
public void testBitwiseSimple(Nd4jBackend backend){
|
||||||
|
|
||||||
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
|
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(100000).castTo(DataType.INT);
|
||||||
|
@ -50,9 +49,8 @@ public class TestNamespaces extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMathSimple(Nd4jBackend backend) {
|
public void testMathSimple(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1);
|
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 5).muli(2).subi(1);
|
||||||
INDArray abs = Nd4j.math.abs(x);
|
INDArray abs = Nd4j.math.abs(x);
|
||||||
|
@ -67,9 +65,8 @@ public class TestNamespaces extends BaseNd4jTestWithBackends {
|
||||||
// System.out.println(cm);
|
// System.out.println(cm);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRandomSimple(Nd4jBackend backend){
|
public void testRandomSimple(Nd4jBackend backend){
|
||||||
INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10);
|
INDArray normal = Nd4j.random.normal(0, 1, DataType.FLOAT, 10);
|
||||||
// System.out.println(normal);
|
// System.out.println(normal);
|
||||||
|
@ -77,9 +74,8 @@ public class TestNamespaces extends BaseNd4jTestWithBackends {
|
||||||
// System.out.println(uniform);
|
// System.out.println(uniform);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNeuralNetworkSimple(Nd4jBackend backend){
|
public void testNeuralNetworkSimple(Nd4jBackend backend){
|
||||||
INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10));
|
INDArray out = Nd4j.nn.elu(Nd4j.random.normal(0, 1, DataType.FLOAT, 10));
|
||||||
// System.out.println(out);
|
// System.out.println(out);
|
||||||
|
|
|
@ -36,9 +36,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
public class LapackTest extends BaseNd4jTestWithBackends {
|
public class LapackTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testQRSquare(Nd4jBackend backend) {
|
public void testQRSquare(Nd4jBackend backend) {
|
||||||
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||||
A = A.reshape('c', 3, 3);
|
A = A.reshape('c', 3, 3);
|
||||||
|
@ -56,9 +55,8 @@ public class LapackTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testQRRect(Nd4jBackend backend) {
|
public void testQRRect(Nd4jBackend backend) {
|
||||||
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||||
A = A.reshape('f', 4, 3);
|
A = A.reshape('f', 4, 3);
|
||||||
|
@ -76,9 +74,8 @@ public class LapackTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCholeskyL(Nd4jBackend backend) {
|
public void testCholeskyL(Nd4jBackend backend) {
|
||||||
INDArray A = Nd4j.create(new double[] {2, -1, 1, -1, 2, -1, 1, -1, 2,});
|
INDArray A = Nd4j.create(new double[] {2, -1, 1, -1, 2, -1, 1, -1, 2,});
|
||||||
A = A.reshape('c', 3, 3);
|
A = A.reshape('c', 3, 3);
|
||||||
|
@ -95,9 +92,8 @@ public class LapackTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCholeskyU(Nd4jBackend backend) {
|
public void testCholeskyU(Nd4jBackend backend) {
|
||||||
INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,});
|
INDArray A = Nd4j.create(new double[] {3, -1, 2, -1, 3, -1, 2, -1, 3,});
|
||||||
A = A.reshape('f', 3, 3);
|
A = A.reshape('f', 3, 3);
|
||||||
|
|
|
@ -39,9 +39,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class Level1Test extends BaseNd4jTestWithBackends {
|
public class Level1Test extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDot(Nd4jBackend backend) {
|
public void testDot(Nd4jBackend backend) {
|
||||||
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
|
@ -54,9 +53,8 @@ public class Level1Test extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAxpy(Nd4jBackend backend) {
|
public void testAxpy(Nd4jBackend backend) {
|
||||||
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
INDArray row = matrix.getRow(1);
|
INDArray row = matrix.getRow(1);
|
||||||
|
@ -65,9 +63,8 @@ public class Level1Test extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAxpy2(Nd4jBackend backend) {
|
public void testAxpy2(Nd4jBackend backend) {
|
||||||
val rowX = Nd4j.create(new double[]{1, 2, 3, 4});
|
val rowX = Nd4j.create(new double[]{1, 2, 3, 4});
|
||||||
val rowY = Nd4j.create(new double[]{1, 2, 3, 4});
|
val rowY = Nd4j.create(new double[]{1, 2, 3, 4});
|
||||||
|
|
|
@ -34,9 +34,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class Level2Test extends BaseNd4jTestWithBackends {
|
public class Level2Test extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemv1(Nd4jBackend backend) {
|
public void testGemv1(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
||||||
|
@ -50,9 +49,8 @@ public class Level2Test extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(1853350f, array3.getFloat(3), 0.001f);
|
assertEquals(1853350f, array3.getFloat(3), 0.001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemv2(Nd4jBackend backend) {
|
public void testGemv2(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
|
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
|
||||||
|
@ -66,9 +64,8 @@ public class Level2Test extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(1853350f, array3.getFloat(3), 0.001f);
|
assertEquals(1853350f, array3.getFloat(3), 0.001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemv3(Nd4jBackend backend) {
|
public void testGemv3(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
|
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
|
||||||
|
@ -82,9 +79,8 @@ public class Level2Test extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(3353200f, array3.getFloat(3), 0.001f);
|
assertEquals(3353200f, array3.getFloat(3), 0.001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemv4(Nd4jBackend backend) {
|
public void testGemv4(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
||||||
|
@ -98,9 +94,8 @@ public class Level2Test extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(3353200f, array3.getFloat(3), 0.001f);
|
assertEquals(3353200f, array3.getFloat(3), 0.001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemv5(Nd4jBackend backend) {
|
public void testGemv5(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
||||||
|
@ -116,9 +111,8 @@ public class Level2Test extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(1853350f, array3.getFloat(3), 0.001f);
|
assertEquals(1853350f, array3.getFloat(3), 0.001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemv6(Nd4jBackend backend) {
|
public void testGemv6(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
||||||
|
@ -134,9 +128,8 @@ public class Level2Test extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(3353200f, array3.getFloat(3), 0.001f);
|
assertEquals(3353200f, array3.getFloat(3), 0.001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemv7(Nd4jBackend backend) {
|
public void testGemv7(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
||||||
|
|
|
@ -34,9 +34,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class Level3Test extends BaseNd4jTestWithBackends {
|
public class Level3Test extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemm1(Nd4jBackend backend) {
|
public void testGemm1(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape(1, 100);
|
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape(1, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape(100, 1);
|
||||||
|
@ -46,9 +45,8 @@ public class Level3Test extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(338350f, array3.getFloat(0), 0.001f);
|
assertEquals(338350f, array3.getFloat(0), 0.001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemm2(Nd4jBackend backend) {
|
public void testGemm2(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape('f', 1, 100);
|
INDArray array1 = Nd4j.linspace(1, 100, 100).reshape('f', 1, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
|
INDArray array2 = Nd4j.linspace(1, 100, 100).reshape('f', 100, 1);
|
||||||
|
@ -58,9 +56,8 @@ public class Level3Test extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(338350f, array3.getFloat(0), 0.001f);
|
assertEquals(338350f, array3.getFloat(0), 0.001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemm3(Nd4jBackend backend) {
|
public void testGemm3(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
|
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
|
||||||
|
@ -78,9 +75,8 @@ public class Level3Test extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(8328150.0f, array3.data().getFloat(21), 0.001f);
|
assertEquals(8328150.0f, array3.data().getFloat(21), 0.001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemm4(Nd4jBackend backend) {
|
public void testGemm4(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape(10, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);
|
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);
|
||||||
|
@ -97,9 +93,8 @@ public class Level3Test extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(3853350f, array3.data().getFloat(21), 0.001f);
|
assertEquals(3853350f, array3.data().getFloat(21), 0.001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemm5(Nd4jBackend backend) {
|
public void testGemm5(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
|
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape(100, 10);
|
||||||
|
@ -113,9 +108,8 @@ public class Level3Test extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(3.3835E7f, array3.data().getFloat(99), 10f);
|
assertEquals(3.3835E7f, array3.data().getFloat(99), 10f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemm6(Nd4jBackend backend) {
|
public void testGemm6(Nd4jBackend backend) {
|
||||||
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
INDArray array1 = Nd4j.linspace(1, 1000, 1000).reshape('f', 10, 100);
|
||||||
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);
|
INDArray array2 = Nd4j.linspace(1, 1000, 1000).reshape('f', 100, 10);
|
||||||
|
|
|
@ -37,9 +37,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
public class ParamsTestsF extends BaseNd4jTestWithBackends {
|
public class ParamsTestsF extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGemm (Nd4jBackend backend) {
|
public void testGemm (Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.create(2, 2);
|
INDArray a = Nd4j.create(2, 2);
|
||||||
INDArray b = Nd4j.create(2, 3);
|
INDArray b = Nd4j.create(2, 3);
|
||||||
|
|
|
@ -53,7 +53,7 @@ public class DataBufferTests extends BaseNd4jTestWithBackends {
|
||||||
@Test
|
@Test
|
||||||
@Disabled("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657")
|
@Disabled("AB 2019/06/03 - CI issue: \"CUDA stream synchronization failed\" - see issue 7657")
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNoArgCreateBufferFromArray(Nd4jBackend backend) {
|
public void testNoArgCreateBufferFromArray(Nd4jBackend backend) {
|
||||||
|
|
||||||
//Tests here:
|
//Tests here:
|
||||||
|
@ -279,9 +279,8 @@ public class DataBufferTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCreateTypedBuffer(Nd4jBackend backend) {
|
public void testCreateTypedBuffer(Nd4jBackend backend) {
|
||||||
|
|
||||||
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
||||||
|
@ -351,9 +350,8 @@ public class DataBufferTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAsBytes(Nd4jBackend backend) {
|
public void testAsBytes(Nd4jBackend backend) {
|
||||||
INDArray orig = Nd4j.linspace(DataType.INT, 0, 10, 1);
|
INDArray orig = Nd4j.linspace(DataType.INT, 0, 10, 1);
|
||||||
|
|
||||||
|
@ -408,9 +406,8 @@ public class DataBufferTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testEnsureLocation(){
|
public void testEnsureLocation(){
|
||||||
//https://github.com/eclipse/deeplearning4j/issues/8783
|
//https://github.com/eclipse/deeplearning4j/issues/8783
|
||||||
Nd4j.create(1);
|
Nd4j.create(1);
|
||||||
|
|
|
@ -72,7 +72,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
|
||||||
*/
|
*/
|
||||||
@Test()
|
@Test()
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBlasValidation1(Nd4jBackend backend) {
|
public void testBlasValidation1(Nd4jBackend backend) {
|
||||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||||
INDArray x = Nd4j.create(10);
|
INDArray x = Nd4j.create(10);
|
||||||
|
@ -91,7 +91,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
|
||||||
*/
|
*/
|
||||||
@Test()
|
@Test()
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBlasValidation2(Nd4jBackend backend) {
|
public void testBlasValidation2(Nd4jBackend backend) {
|
||||||
assertThrows(RuntimeException.class,() -> {
|
assertThrows(RuntimeException.class,() -> {
|
||||||
INDArray a = Nd4j.create(100, 10);
|
INDArray a = Nd4j.create(100, 10);
|
||||||
|
@ -111,7 +111,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
|
||||||
*/
|
*/
|
||||||
@Test()
|
@Test()
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBlasValidation3(Nd4jBackend backend) {
|
public void testBlasValidation3(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
INDArray x = Nd4j.create(100, 100);
|
INDArray x = Nd4j.create(100, 100);
|
||||||
|
|
|
@ -76,9 +76,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
DataTypeUtil.setDTypeForContext(initialType);
|
DataTypeUtil.setDTypeForContext(initialType);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPointerCreation(Nd4jBackend backend) {
|
public void testPointerCreation(Nd4jBackend backend) {
|
||||||
DoublePointer floatPointer = new DoublePointer(1, 2, 3, 4);
|
DoublePointer floatPointer = new DoublePointer(1, 2, 3, 4);
|
||||||
Indexer indexer = DoubleIndexer.create(floatPointer);
|
Indexer indexer = DoubleIndexer.create(floatPointer);
|
||||||
|
@ -87,9 +86,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(other.asDouble(), buffer.asDouble(), 0.001);
|
assertArrayEquals(other.asDouble(), buffer.asDouble(), 0.001);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testGetSet(Nd4jBackend backend) {
|
public void testGetSet(Nd4jBackend backend) {
|
||||||
double[] d1 = new double[] {1, 2, 3, 4};
|
double[] d1 = new double[] {1, 2, 3, 4};
|
||||||
DataBuffer d = Nd4j.createBuffer(d1);
|
DataBuffer d = Nd4j.createBuffer(d1);
|
||||||
|
@ -100,9 +98,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testSerialization2() throws Exception {
|
public void testSerialization2() throws Exception {
|
||||||
INDArray[] arr = new INDArray[] {Nd4j.ones(1, 10),
|
INDArray[] arr = new INDArray[] {Nd4j.ones(1, 10),
|
||||||
// Nd4j.ones(5,10).getRow(2)
|
// Nd4j.ones(5,10).getRow(2)
|
||||||
|
@ -130,9 +127,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testSerialization(@TempDir Path testDir) throws Exception {
|
public void testSerialization(@TempDir Path testDir) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
DataBuffer buf = Nd4j.createBuffer(5);
|
DataBuffer buf = Nd4j.createBuffer(5);
|
||||||
|
@ -154,9 +150,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testDup(Nd4jBackend backend) {
|
public void testDup(Nd4jBackend backend) {
|
||||||
double[] d1 = new double[] {1, 2, 3, 4};
|
double[] d1 = new double[] {1, 2, 3, 4};
|
||||||
DataBuffer d = Nd4j.createBuffer(d1);
|
DataBuffer d = Nd4j.createBuffer(d1);
|
||||||
|
@ -166,9 +161,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testPut(Nd4jBackend backend) {
|
public void testPut(Nd4jBackend backend) {
|
||||||
double[] d1 = new double[] {1, 2, 3, 4};
|
double[] d1 = new double[] {1, 2, 3, 4};
|
||||||
DataBuffer d = Nd4j.createBuffer(d1);
|
DataBuffer d = Nd4j.createBuffer(d1);
|
||||||
|
@ -179,9 +173,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testGetRange(Nd4jBackend backend) {
|
public void testGetRange(Nd4jBackend backend) {
|
||||||
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
|
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
|
||||||
double[] get = buffer.getDoublesAt(0, 3);
|
double[] get = buffer.getDoublesAt(0, 3);
|
||||||
|
@ -196,9 +189,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testGetOffsetRange(Nd4jBackend backend) {
|
public void testGetOffsetRange(Nd4jBackend backend) {
|
||||||
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
|
DataBuffer buffer = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).data();
|
||||||
double[] get = buffer.getDoublesAt(1, 3);
|
double[] get = buffer.getDoublesAt(1, 3);
|
||||||
|
@ -213,9 +205,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testAssign(Nd4jBackend backend) {
|
public void testAssign(Nd4jBackend backend) {
|
||||||
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
|
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
|
||||||
DataBuffer one = Nd4j.createBuffer(new double[] {1});
|
DataBuffer one = Nd4j.createBuffer(new double[] {1});
|
||||||
|
@ -226,9 +217,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testOffset(Nd4jBackend backend) {
|
public void testOffset(Nd4jBackend backend) {
|
||||||
DataBuffer create = Nd4j.createBuffer(new double[] {1, 2, 3, 4}, 2);
|
DataBuffer create = Nd4j.createBuffer(new double[] {1, 2, 3, 4}, 2);
|
||||||
assertEquals(2, create.length());
|
assertEquals(2, create.length());
|
||||||
|
@ -238,9 +228,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testReallocation(Nd4jBackend backend) {
|
public void testReallocation(Nd4jBackend backend) {
|
||||||
DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4});
|
DataBuffer buffer = Nd4j.createBuffer(new double[] {1, 2, 3, 4});
|
||||||
assertEquals(4, buffer.capacity());
|
assertEquals(4, buffer.capacity());
|
||||||
|
@ -250,9 +239,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(old, Arrays.copyOf(buffer.asDouble(), 4), 1e-1);
|
assertArrayEquals(old, Arrays.copyOf(buffer.asDouble(), 4), 1e-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testReallocationWorkspace(Nd4jBackend backend) {
|
public void testReallocationWorkspace(Nd4jBackend backend) {
|
||||||
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
||||||
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
||||||
|
@ -269,9 +257,8 @@ public class DoubleDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
|
||||||
public void testAddressPointer(){
|
public void testAddressPointer(){
|
||||||
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
|
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -72,9 +72,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPointerCreation(Nd4jBackend backend) {
|
public void testPointerCreation(Nd4jBackend backend) {
|
||||||
FloatPointer floatPointer = new FloatPointer(1, 2, 3, 4);
|
FloatPointer floatPointer = new FloatPointer(1, 2, 3, 4);
|
||||||
Indexer indexer = FloatIndexer.create(floatPointer);
|
Indexer indexer = FloatIndexer.create(floatPointer);
|
||||||
|
@ -83,9 +82,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(other.asFloat(), buffer.asFloat(), 0.001f);
|
assertArrayEquals(other.asFloat(), buffer.asFloat(), 0.001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetSet(Nd4jBackend backend) {
|
public void testGetSet(Nd4jBackend backend) {
|
||||||
float[] d1 = new float[] {1, 2, 3, 4};
|
float[] d1 = new float[] {1, 2, 3, 4};
|
||||||
DataBuffer d = Nd4j.createBuffer(d1);
|
DataBuffer d = Nd4j.createBuffer(d1);
|
||||||
|
@ -96,9 +94,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSerialization(@TempDir Path tempDir,Nd4jBackend backend) throws Exception {
|
public void testSerialization(@TempDir Path tempDir,Nd4jBackend backend) throws Exception {
|
||||||
File dir = tempDir.toFile();
|
File dir = tempDir.toFile();
|
||||||
DataBuffer buf = Nd4j.createBuffer(5);
|
DataBuffer buf = Nd4j.createBuffer(5);
|
||||||
|
@ -119,9 +116,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(buf.asFloat(), buf2.asFloat(), 0.0001f);
|
assertArrayEquals(buf.asFloat(), buf2.asFloat(), 0.0001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDup(Nd4jBackend backend) {
|
public void testDup(Nd4jBackend backend) {
|
||||||
float[] d1 = new float[] {1, 2, 3, 4};
|
float[] d1 = new float[] {1, 2, 3, 4};
|
||||||
DataBuffer d = Nd4j.createBuffer(d1);
|
DataBuffer d = Nd4j.createBuffer(d1);
|
||||||
|
@ -129,9 +125,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(d.asFloat(), d2.asFloat(), 0.001f);
|
assertArrayEquals(d.asFloat(), d2.asFloat(), 0.001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testToNio(Nd4jBackend backend) {
|
public void testToNio(Nd4jBackend backend) {
|
||||||
DataBuffer buff = Nd4j.createTypedBuffer(new double[] {1, 2, 3, 4}, DataType.FLOAT);
|
DataBuffer buff = Nd4j.createTypedBuffer(new double[] {1, 2, 3, 4}, DataType.FLOAT);
|
||||||
assertEquals(4, buff.length());
|
assertEquals(4, buff.length());
|
||||||
|
@ -143,9 +138,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPut(Nd4jBackend backend) {
|
public void testPut(Nd4jBackend backend) {
|
||||||
float[] d1 = new float[] {1, 2, 3, 4};
|
float[] d1 = new float[] {1, 2, 3, 4};
|
||||||
DataBuffer d = Nd4j.createBuffer(d1);
|
DataBuffer d = Nd4j.createBuffer(d1);
|
||||||
|
@ -156,9 +150,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetRange(Nd4jBackend backend) {
|
public void testGetRange(Nd4jBackend backend) {
|
||||||
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
||||||
float[] get = buffer.getFloatsAt(0, 3);
|
float[] get = buffer.getFloatsAt(0, 3);
|
||||||
|
@ -174,9 +167,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetOffsetRange(Nd4jBackend backend) {
|
public void testGetOffsetRange(Nd4jBackend backend) {
|
||||||
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
||||||
float[] get = buffer.getFloatsAt(1, 3);
|
float[] get = buffer.getFloatsAt(1, 3);
|
||||||
|
@ -193,9 +185,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAsBytes(Nd4jBackend backend) {
|
public void testAsBytes(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.create(5);
|
INDArray arr = Nd4j.create(5);
|
||||||
byte[] d = arr.data().asBytes();
|
byte[] d = arr.data().asBytes();
|
||||||
|
@ -205,9 +196,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAssign(Nd4jBackend backend) {
|
public void testAssign(Nd4jBackend backend) {
|
||||||
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
|
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
|
||||||
DataBuffer one = Nd4j.createBuffer(new double[] {1});
|
DataBuffer one = Nd4j.createBuffer(new double[] {1});
|
||||||
|
@ -217,9 +207,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(assertion.asFloat(), blank.asFloat(), 0.0001f);
|
assertArrayEquals(assertion.asFloat(), blank.asFloat(), 0.0001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReadWrite(Nd4jBackend backend) throws Exception {
|
public void testReadWrite(Nd4jBackend backend) throws Exception {
|
||||||
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
|
DataBuffer assertion = Nd4j.createBuffer(new double[] {1, 2, 3});
|
||||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||||
|
@ -233,9 +222,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(assertion.asFloat(), clone.asFloat(), 0.0001f);
|
assertArrayEquals(assertion.asFloat(), clone.asFloat(), 0.0001f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOffset(Nd4jBackend backend) {
|
public void testOffset(Nd4jBackend backend) {
|
||||||
DataBuffer create = Nd4j.createBuffer(new float[] {1, 2, 3, 4}, 2);
|
DataBuffer create = Nd4j.createBuffer(new float[] {1, 2, 3, 4}, 2);
|
||||||
assertEquals(2, create.length());
|
assertEquals(2, create.length());
|
||||||
|
@ -245,9 +233,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReallocation(Nd4jBackend backend) {
|
public void testReallocation(Nd4jBackend backend) {
|
||||||
DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4});
|
DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4});
|
||||||
assertEquals(4, buffer.capacity());
|
assertEquals(4, buffer.capacity());
|
||||||
|
@ -258,9 +245,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(old, newBuf, 1e-4F);
|
assertArrayEquals(old, newBuf, 1e-4F);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReallocationWorkspace(Nd4jBackend backend) {
|
public void testReallocationWorkspace(Nd4jBackend backend) {
|
||||||
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
||||||
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
||||||
|
@ -277,9 +263,8 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
workspace.close();
|
workspace.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAddressPointer(Nd4jBackend backend){
|
public void testAddressPointer(Nd4jBackend backend){
|
||||||
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
|
if( Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.NATIVE_CPU ){
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -42,9 +42,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
public class IntDataBufferTests extends BaseNd4jTestWithBackends {
|
public class IntDataBufferTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBasicSerde1() throws Exception {
|
public void testBasicSerde1() throws Exception {
|
||||||
|
|
||||||
|
|
||||||
|
@ -82,9 +81,8 @@ public class IntDataBufferTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReallocation(Nd4jBackend backend) {
|
public void testReallocation(Nd4jBackend backend) {
|
||||||
DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
|
DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
|
||||||
assertEquals(4, buffer.capacity());
|
assertEquals(4, buffer.capacity());
|
||||||
|
@ -96,9 +94,8 @@ public class IntDataBufferTests extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(old, Arrays.copyOf(newContent, old.length));
|
assertArrayEquals(old, Arrays.copyOf(newContent, old.length));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReallocationWorkspace(Nd4jBackend backend) {
|
public void testReallocationWorkspace(Nd4jBackend backend) {
|
||||||
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
||||||
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
||||||
|
|
|
@ -43,9 +43,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testINDArrayIndexingEqualToRank(Nd4jBackend backend) {
|
public void testINDArrayIndexingEqualToRank(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
|
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
|
||||||
INDArray indexes = Nd4j.create(new double[][]{
|
INDArray indexes = Nd4j.create(new double[][]{
|
||||||
|
@ -60,9 +59,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testINDArrayIndexingLessThanRankSimple(Nd4jBackend backend) {
|
public void testINDArrayIndexingLessThanRankSimple(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
|
INDArray x = Nd4j.linspace(1,6,6, DataType.DOUBLE).reshape('c',3,2).castTo(DataType.DOUBLE);
|
||||||
INDArray indexes = Nd4j.create(new double[][]{
|
INDArray indexes = Nd4j.create(new double[][]{
|
||||||
|
@ -76,9 +74,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testINDArrayIndexingLessThanRankFourDimension(Nd4jBackend backend) {
|
public void testINDArrayIndexingLessThanRankFourDimension(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE);
|
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE);
|
||||||
INDArray indexes = Nd4j.create(new double[][]{
|
INDArray indexes = Nd4j.create(new double[][]{
|
||||||
|
@ -91,9 +88,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPutSimple(Nd4jBackend backend) {
|
public void testPutSimple(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2);
|
INDArray x = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2);
|
||||||
INDArray indexes = Nd4j.create(new double[][]{
|
INDArray indexes = Nd4j.create(new double[][]{
|
||||||
|
@ -105,9 +101,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(vals,x);
|
assertEquals(vals,x);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetScalar(Nd4jBackend backend) {
|
public void testGetScalar(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
|
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
|
||||||
INDArray d = arr.get(NDArrayIndex.point(1));
|
INDArray d = arr.get(NDArrayIndex.point(1));
|
||||||
|
@ -116,18 +111,16 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNewAxis(Nd4jBackend backend) {
|
public void testNewAxis(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
|
INDArray arr = Nd4j.rand(new int[] {4, 2, 3});
|
||||||
INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
|
INDArray view = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all(), NDArrayIndex.point(1));
|
||||||
// System.out.println(view);
|
// System.out.println(view);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testVectorIndexing(Nd4jBackend backend) {
|
public void testVectorIndexing(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE);
|
INDArray x = Nd4j.linspace(0, 10, 11, DataType.DOUBLE).reshape(1, 11).castTo(DataType.DOUBLE);
|
||||||
int[] index = new int[] {5, 8, 9};
|
int[] index = new int[] {5, 8, 9};
|
||||||
|
@ -139,9 +132,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetRowsColumnsMatrix(Nd4jBackend backend) {
|
public void testGetRowsColumnsMatrix(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 6);
|
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 6);
|
||||||
INDArray firstAndSecondColumnsAssertion = Nd4j.create(new double[][] {{1, 5}, {2, 6}, {3, 7}, {4, 8}});
|
INDArray firstAndSecondColumnsAssertion = Nd4j.create(new double[][] {{1, 5}, {2, 6}, {3, 7}, {4, 8}});
|
||||||
|
@ -159,9 +151,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSlicing(Nd4jBackend backend) {
|
public void testSlicing(Nd4jBackend backend) {
|
||||||
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
|
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
|
||||||
INDArray slice1Assert = Nd4j.create(new double[] {2, 6, 10, 14});
|
INDArray slice1Assert = Nd4j.create(new double[] {2, 6, 10, 14});
|
||||||
|
@ -169,9 +160,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(slice1Assert, slice1Test);
|
assertEquals(slice1Assert, slice1Test);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testArangeMul(Nd4jBackend backend) {
|
public void testArangeMul(Nd4jBackend backend) {
|
||||||
INDArray arange = Nd4j.arange(1, 17).reshape('f', 4, 4).castTo(DataType.DOUBLE);
|
INDArray arange = Nd4j.arange(1, 17).reshape('f', 4, 4).castTo(DataType.DOUBLE);
|
||||||
INDArrayIndex index = NDArrayIndex.interval(0, 2);
|
INDArrayIndex index = NDArrayIndex.interval(0, 2);
|
||||||
|
@ -183,9 +173,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetIndicesVector(Nd4jBackend backend) {
|
public void testGetIndicesVector(Nd4jBackend backend) {
|
||||||
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
|
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
|
||||||
INDArray test = Nd4j.create(new double[] {2, 3});
|
INDArray test = Nd4j.create(new double[] {2, 3});
|
||||||
|
@ -193,9 +182,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(test, result);
|
assertEquals(test, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetIndicesVectorView(Nd4jBackend backend) {
|
public void testGetIndicesVectorView(Nd4jBackend backend) {
|
||||||
INDArray matrix = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape('c',5, 5);
|
INDArray matrix = Nd4j.linspace(1, 25, 25, DataType.DOUBLE).reshape('c',5, 5);
|
||||||
INDArray column = matrix.getColumn(0).reshape(1,5);
|
INDArray column = matrix.getColumn(0).reshape(1,5);
|
||||||
|
@ -213,9 +201,8 @@ public class IndexingTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp2, result);
|
assertEquals(exp2, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTest#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void test2dGetPoint(Nd4jBackend backend){
|
public void test2dGetPoint(Nd4jBackend backend){
|
||||||
INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4);
|
INDArray arr = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4);
|
||||||
for( int i=0; i<3; i++ ){
|
for( int i=0; i<3; i++ ){
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue