diff --git a/.github/workflows/cpu-integration-tests.yaml b/.github/workflows/cpu-integration-tests.yaml
index dff8b29ad..bba0e345d 100644
--- a/.github/workflows/cpu-integration-tests.yaml
+++ b/.github/workflows/cpu-integration-tests.yaml
@@ -31,7 +31,7 @@ jobs:
protoc --version
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
export OMP_NUM_THREADS=1
- mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
+ mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
windows-x86_64:
runs-on: windows-2019
@@ -44,7 +44,7 @@ jobs:
run: |
set "PATH=C:\msys64\usr\bin;%PATH%"
export OMP_NUM_THREADS=1
- mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
+ mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
@@ -60,5 +60,5 @@ jobs:
run: |
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
export OMP_NUM_THREADS=1
- mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
+ mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
diff --git a/.github/workflows/cpu-sanity-check-tests.yaml b/.github/workflows/cpu-sanity-check-tests.yaml
index 2737672bc..fbc2514cf 100644
--- a/.github/workflows/cpu-sanity-check-tests.yaml
+++ b/.github/workflows/cpu-sanity-check-tests.yaml
@@ -31,7 +31,7 @@ jobs:
protoc --version
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
export OMP_NUM_THREADS=1
- mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
+ mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
windows-x86_64:
runs-on: windows-2019
@@ -44,7 +44,7 @@ jobs:
run: |
set "PATH=C:\msys64\usr\bin;%PATH%"
export OMP_NUM_THREADS=1
- mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
+ mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
@@ -60,5 +60,5 @@ jobs:
run: |
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
export OMP_NUM_THREADS=1
- mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
+ mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
diff --git a/.github/workflows/run-cpu-tests-sanity-checks.yml b/.github/workflows/run-cpu-tests-sanity-checks.yml
index 47202170c..c44ae3f03 100644
--- a/.github/workflows/run-cpu-tests-sanity-checks.yml
+++ b/.github/workflows/run-cpu-tests-sanity-checks.yml
@@ -34,5 +34,5 @@ jobs:
cmake --version
protoc --version
export OMP_NUM_THREADS=1
- mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Ptest-nd4j-native --also-make clean test
+ mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Pnd4j-tests-cpu --also-make clean test
diff --git a/datavec/datavec-api/pom.xml b/datavec/datavec-api/pom.xml
index fc091c5dd..0c7971201 100644
--- a/datavec/datavec-api/pom.xml
+++ b/datavec/datavec-api/pom.xml
@@ -109,10 +109,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java
index 87d313ded..f59a264df 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java
@@ -30,6 +30,8 @@ import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.loader.FileBatch;
import java.io.File;
@@ -40,13 +42,16 @@ import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
+import org.nd4j.linalg.factory.Nd4jBackend;
@DisplayName("File Batch Record Reader Test")
-class FileBatchRecordReaderTest extends BaseND4JTest {
+public class FileBatchRecordReaderTest extends BaseND4JTest {
+ @TempDir Path testDir;
- @Test
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Csv")
- void testCsv(@TempDir Path testDir) throws Exception {
+ void testCsv(Nd4jBackend backend) throws Exception {
// This is an unrealistic use case - one line/record per CSV
File baseDir = testDir.toFile();
List fileList = new ArrayList<>();
@@ -75,9 +80,10 @@ class FileBatchRecordReaderTest extends BaseND4JTest {
}
}
- @Test
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Csv Sequence")
- void testCsvSequence(@TempDir Path testDir) throws Exception {
+ void testCsvSequence(Nd4jBackend backend) throws Exception {
// CSV sequence - 3 lines per file, 10 files
File baseDir = testDir.toFile();
List fileList = new ArrayList<>();
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java
index c2549b405..fa1d82279 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java
@@ -21,7 +21,6 @@ package org.datavec.api.transform.ops;
import org.junit.jupiter.api.Test;
-import org.junit.rules.ExpectedException;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
diff --git a/datavec/datavec-arrow/pom.xml b/datavec/datavec-arrow/pom.xml
index 0d30f07a9..f19f5d6ba 100644
--- a/datavec/datavec-arrow/pom.xml
+++ b/datavec/datavec-arrow/pom.xml
@@ -60,10 +60,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/datavec/datavec-data/datavec-data-image/pom.xml b/datavec/datavec-data/datavec-data-image/pom.xml
index 20f4a7d9e..1b786b59a 100644
--- a/datavec/datavec-data/datavec-data-image/pom.xml
+++ b/datavec/datavec-data/datavec-data-image/pom.xml
@@ -119,10 +119,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/datavec/datavec-data/pom.xml b/datavec/datavec-data/pom.xml
index d5bfd6d05..8ed687669 100644
--- a/datavec/datavec-data/pom.xml
+++ b/datavec/datavec-data/pom.xml
@@ -59,10 +59,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/datavec/datavec-excel/pom.xml b/datavec/datavec-excel/pom.xml
index 9b532ca1e..7e3d2dbd2 100644
--- a/datavec/datavec-excel/pom.xml
+++ b/datavec/datavec-excel/pom.xml
@@ -57,10 +57,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/datavec/datavec-jdbc/pom.xml b/datavec/datavec-jdbc/pom.xml
index 39bd2cff1..0339dbe98 100644
--- a/datavec/datavec-jdbc/pom.xml
+++ b/datavec/datavec-jdbc/pom.xml
@@ -65,10 +65,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/datavec/datavec-local/pom.xml b/datavec/datavec-local/pom.xml
index 195ed2cb4..9f0480274 100644
--- a/datavec/datavec-local/pom.xml
+++ b/datavec/datavec-local/pom.xml
@@ -61,25 +61,18 @@
nd4j-common
- org.datavec
- datavec-geo
+ org.nd4j
+ python4j-numpy
${project.version}
- test
-
-
- org.datavec
- datavec-python
- ${project.version}
- test
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java
index 4a85c255b..8284d22b7 100644
--- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java
+++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java
@@ -29,7 +29,6 @@ import org.datavec.api.transform.reduce.Reducer;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.*;
-import org.datavec.python.PythonTransform;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
@@ -39,7 +38,6 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
-import org.junit.jupiter.api.extension.ExtendWith;
import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertTimeout;
@@ -166,37 +164,8 @@ class ExecutionTest {
List> out = outRdd;
List> expOut = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
out = new ArrayList<>(out);
- Collections.sort(out, new Comparator>() {
-
- @Override
- public int compare(List o1, List o2) {
- return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
- }
- });
+ Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
assertEquals(expOut, out);
}
- @Test
- @Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
- @DisplayName("Test Python Execution Ndarray")
- void testPythonExecutionNdarray() {
- assertTimeout(ofMillis(60000), () -> {
- Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build();
- TransformProcess transformProcess = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build();
- List> functions = new ArrayList<>();
- List firstRow = new ArrayList<>();
- INDArray firstArr = Nd4j.linspace(1, 4, 4);
- INDArray secondArr = Nd4j.linspace(1, 4, 4);
- firstRow.add(new NDArrayWritable(firstArr));
- firstRow.add(new NDArrayWritable(secondArr));
- functions.add(firstRow);
- List> execute = LocalTransformExecutor.execute(functions, transformProcess);
- INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
- INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
- INDArray expected = Transforms.sin(firstArr);
- INDArray secondExpected = Transforms.cos(secondArr);
- assertEquals(expected, firstResult);
- assertEquals(secondExpected, secondResult);
- });
- }
}
diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java
deleted file mode 100644
index f81fdfd2e..000000000
--- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java
+++ /dev/null
@@ -1,155 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.local.transforms.transform;
-
-import org.datavec.api.transform.ColumnType;
-import org.datavec.api.transform.Transform;
-import org.datavec.api.transform.geo.LocationType;
-import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform;
-import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform;
-import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform;
-import org.datavec.api.writable.DoubleWritable;
-import org.datavec.api.writable.Text;
-import org.datavec.api.writable.Writable;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Test;
-import org.nd4j.common.io.ClassPathResource;
-
-import java.io.*;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-/**
- * @author saudet
- */
-public class TestGeoTransforms {
-
- @BeforeAll
- public static void beforeClass() throws Exception {
- //Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
- File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
- System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath());
- }
-
- @AfterClass
- public static void afterClass(){
- System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, "");
- }
-
-
- @Test
- public void testCoordinatesDistanceTransform() throws Exception {
- Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
- .build();
-
- Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
- transform.setInputSchema(schema);
-
- Schema out = transform.transform(schema);
- assertEquals(4, out.numColumns());
- assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
- assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
- out.getColumnTypes());
-
- assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
- transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
- assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
- new DoubleWritable(Math.sqrt(160))),
- transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
- new Text("10|5"))));
- }
-
- @Test
- public void testIPAddressToCoordinatesTransform() throws Exception {
- Schema schema = new Schema.Builder().addColumnString("column").build();
-
- Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER");
- transform.setInputSchema(schema);
-
- Schema out = transform.transform(schema);
-
- assertEquals(1, out.getColumnMetaData().size());
- assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
-
- String in = "81.2.69.160";
- double latitude = 51.5142;
- double longitude = -0.0931;
-
- List writables = transform.map(Collections.singletonList((Writable) new Text(in)));
- assertEquals(1, writables.size());
- String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
- assertEquals(2, coordinates.length);
- assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
- assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
-
- //Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
- ObjectOutputStream oos = new ObjectOutputStream(baos);
- oos.writeObject(transform);
-
- byte[] bytes = baos.toByteArray();
-
- ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
- ObjectInputStream ois = new ObjectInputStream(bais);
-
- Transform deserialized = (Transform) ois.readObject();
- writables = deserialized.map(Collections.singletonList((Writable) new Text(in)));
- assertEquals(1, writables.size());
- coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
- //System.out.println(Arrays.toString(coordinates));
- assertEquals(2, coordinates.length);
- assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
- assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
- }
-
- @Test
- public void testIPAddressToLocationTransform() throws Exception {
- Schema schema = new Schema.Builder().addColumnString("column").build();
- LocationType[] locationTypes = LocationType.values();
- String in = "81.2.69.160";
- String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167",
- "51.5142:-0.0931", "", "England", "6269131"}; //Note: no postcode in this test DB for this record
-
- for (int i = 0; i < locationTypes.length; i++) {
- LocationType locationType = locationTypes[i];
- String location = locations[i];
-
- Transform transform = new IPAddressToLocationTransform("column", locationType);
- transform.setInputSchema(schema);
-
- Schema out = transform.transform(schema);
-
- assertEquals(1, out.getColumnMetaData().size());
- assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
-
- List writables = transform.map(Collections.singletonList((Writable) new Text(in)));
- assertEquals(1, writables.size());
- assertEquals(location, writables.get(0).toString());
- //System.out.println(location);
- }
- }
-}
diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java
deleted file mode 100644
index 2ef20194d..000000000
--- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java
+++ /dev/null
@@ -1,386 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.local.transforms.transform;
-
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.api.transform.condition.Condition;
-import org.datavec.api.transform.filter.ConditionFilter;
-import org.datavec.api.transform.filter.Filter;
-import org.datavec.api.transform.schema.Schema;
-import org.datavec.local.transforms.LocalTransformExecutor;
-
-import org.datavec.api.writable.*;
-import org.datavec.python.PythonCondition;
-import org.datavec.python.PythonTransform;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.Timeout;
-import org.nd4j.linalg.api.buffer.DataType;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-
-import javax.annotation.concurrent.NotThreadSafe;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-
-import static org.datavec.api.transform.schema.Schema.Builder;
-import static org.junit.jupiter.api.Assertions.*;
-
-@NotThreadSafe
-public class TestPythonTransformProcess {
-
-
- @Test()
- public void testStringConcat() throws Exception{
- Builder schemaBuilder = new Builder();
- schemaBuilder
- .addColumnString("col1")
- .addColumnString("col2");
-
- Schema initialSchema = schemaBuilder.build();
- schemaBuilder.addColumnString("col3");
- Schema finalSchema = schemaBuilder.build();
-
- String pythonCode = "col3 = col1 + col2";
-
- TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
- PythonTransform.builder().code(pythonCode)
- .outputSchema(finalSchema)
- .build()
- ).build();
-
- List inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
-
- List outputs = tp.execute(inputs);
- assertEquals((outputs.get(0)).toString(), "Hello ");
- assertEquals((outputs.get(1)).toString(), "World!");
- assertEquals((outputs.get(2)).toString(), "Hello World!");
-
- }
-
- @Test()
- @Timeout(60000L)
- public void testMixedTypes() throws Exception {
- Builder schemaBuilder = new Builder();
- schemaBuilder
- .addColumnInteger("col1")
- .addColumnFloat("col2")
- .addColumnString("col3")
- .addColumnDouble("col4");
-
-
- Schema initialSchema = schemaBuilder.build();
- schemaBuilder.addColumnInteger("col5");
- Schema finalSchema = schemaBuilder.build();
-
- String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
-
- TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
- PythonTransform.builder().code(pythonCode)
- .outputSchema(finalSchema)
- .inputSchema(initialSchema)
- .build() ).build();
-
- List inputs = Arrays.asList(new IntWritable(10),
- new FloatWritable(3.5f),
- new Text("5"),
- new DoubleWritable(2.0)
- );
-
- List outputs = tp.execute(inputs);
- assertEquals(((LongWritable)outputs.get(4)).get(), 36);
- }
-
- @Test()
- @Timeout(60000L)
- public void testNDArray() throws Exception {
- long[] shape = new long[]{3, 2};
- INDArray arr1 = Nd4j.rand(shape);
- INDArray arr2 = Nd4j.rand(shape);
-
- INDArray expectedOutput = arr1.add(arr2);
-
- Builder schemaBuilder = new Builder();
- schemaBuilder
- .addColumnNDArray("col1", shape)
- .addColumnNDArray("col2", shape);
-
- Schema initialSchema = schemaBuilder.build();
- schemaBuilder.addColumnNDArray("col3", shape);
- Schema finalSchema = schemaBuilder.build();
-
- String pythonCode = "col3 = col1 + col2";
- TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
- PythonTransform.builder().code(pythonCode)
- .outputSchema(finalSchema)
- .build() ).build();
-
- List inputs = Arrays.asList(
- (Writable)
- new NDArrayWritable(arr1),
- new NDArrayWritable(arr2)
- );
-
- List outputs = tp.execute(inputs);
- assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
- assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
- assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
-
- }
-
- @Test()
- @Timeout(60000L)
- public void testNDArray2() throws Exception {
- long[] shape = new long[]{3, 2};
- INDArray arr1 = Nd4j.rand(shape);
- INDArray arr2 = Nd4j.rand(shape);
-
- INDArray expectedOutput = arr1.add(arr2);
-
- Builder schemaBuilder = new Builder();
- schemaBuilder
- .addColumnNDArray("col1", shape)
- .addColumnNDArray("col2", shape);
-
- Schema initialSchema = schemaBuilder.build();
- schemaBuilder.addColumnNDArray("col3", shape);
- Schema finalSchema = schemaBuilder.build();
-
- String pythonCode = "col3 = col1 + col2";
- TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
- PythonTransform.builder().code(pythonCode)
- .outputSchema(finalSchema)
- .build() ).build();
-
- List inputs = Arrays.asList(
- (Writable)
- new NDArrayWritable(arr1),
- new NDArrayWritable(arr2)
- );
-
- List outputs = tp.execute(inputs);
- assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
- assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
- assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
-
- }
-
- @Test()
- @Timeout(60000L)
- public void testNDArrayMixed() throws Exception{
- long[] shape = new long[]{3, 2};
- INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
- INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
- INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
-
- Builder schemaBuilder = new Builder();
- schemaBuilder
- .addColumnNDArray("col1", shape)
- .addColumnNDArray("col2", shape);
-
- Schema initialSchema = schemaBuilder.build();
- schemaBuilder.addColumnNDArray("col3", shape);
- Schema finalSchema = schemaBuilder.build();
-
- String pythonCode = "col3 = col1 + col2";
- TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
- PythonTransform.builder().code(pythonCode)
- .outputSchema(finalSchema)
- .build()
- ).build();
-
- List inputs = Arrays.asList(
- (Writable)
- new NDArrayWritable(arr1),
- new NDArrayWritable(arr2)
- );
-
- List outputs = tp.execute(inputs);
- assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
- assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
- assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
-
- }
-
- @Test()
- @Timeout(60000L)
- public void testPythonFilter() {
- Schema schema = new Builder().addColumnInteger("column").build();
-
- Condition condition = new PythonCondition(
- "f = lambda: column < 0"
- );
-
- condition.setInputSchema(schema);
-
- Filter filter = new ConditionFilter(condition);
-
- assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
- assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
- assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
- assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
- assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
-
- }
-
- @Test()
- @Timeout(60000L)
- public void testPythonFilterAndTransform() throws Exception {
- Builder schemaBuilder = new Builder();
- schemaBuilder
- .addColumnInteger("col1")
- .addColumnFloat("col2")
- .addColumnString("col3")
- .addColumnDouble("col4");
-
- Schema initialSchema = schemaBuilder.build();
- schemaBuilder.addColumnString("col6");
- Schema finalSchema = schemaBuilder.build();
-
- Condition condition = new PythonCondition(
- "f = lambda: col1 < 0 and col2 > 10.0"
- );
-
- condition.setInputSchema(initialSchema);
-
- Filter filter = new ConditionFilter(condition);
-
- String pythonCode = "col6 = str(col1 + col2)";
- TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
- PythonTransform.builder().code(pythonCode)
- .outputSchema(finalSchema)
- .build()
- ).filter(
- filter
- ).build();
-
- List> inputs = new ArrayList<>();
- inputs.add(
- Arrays.asList(
- (Writable)
- new IntWritable(5),
- new FloatWritable(3.0f),
- new Text("abcd"),
- new DoubleWritable(2.1))
- );
- inputs.add(
- Arrays.asList(
- (Writable)
- new IntWritable(-3),
- new FloatWritable(3.0f),
- new Text("abcd"),
- new DoubleWritable(2.1))
- );
- inputs.add(
- Arrays.asList(
- (Writable)
- new IntWritable(5),
- new FloatWritable(11.2f),
- new Text("abcd"),
- new DoubleWritable(2.1))
- );
-
- LocalTransformExecutor.execute(inputs,tp);
- }
-
-
- @Test
- public void testPythonTransformNoOutputSpecified() throws Exception {
- PythonTransform pythonTransform = PythonTransform.builder()
- .code("a += 2; b = 'hello world'")
- .returnAllInputs(true)
- .build();
- List> inputs = new ArrayList<>();
- inputs.add(Arrays.asList((Writable)new IntWritable(1)));
- Schema inputSchema = new Builder()
- .addColumnInteger("a")
- .build();
-
- TransformProcess tp = new TransformProcess.Builder(inputSchema)
- .transform(pythonTransform)
- .build();
- List> execute = LocalTransformExecutor.execute(inputs, tp);
- assertEquals(3,execute.get(0).get(0).toInt());
- assertEquals("hello world",execute.get(0).get(1).toString());
-
- }
-
- @Test
- public void testNumpyTransform() {
- PythonTransform pythonTransform = PythonTransform.builder()
- .code("a += 2; b = 'hello world'")
- .returnAllInputs(true)
- .build();
-
- List> inputs = new ArrayList<>();
- inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
- Schema inputSchema = new Builder()
- .addColumnNDArray("a",new long[]{1,1})
- .build();
-
- TransformProcess tp = new TransformProcess.Builder(inputSchema)
- .transform(pythonTransform)
- .build();
- List> execute = LocalTransformExecutor.execute(inputs, tp);
- assertFalse(execute.isEmpty());
- assertNotNull(execute.get(0));
- assertNotNull(execute.get(0).get(0));
- assertNotNull(execute.get(0).get(1));
- assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get());
- assertEquals("hello world",execute.get(0).get(1).toString());
- }
-
- @Test
- public void testWithSetupRun() throws Exception {
-
- PythonTransform pythonTransform = PythonTransform.builder()
- .code("five=None\n" +
- "def setup():\n" +
- " global five\n"+
- " five = 5\n\n" +
- "def run(a, b):\n" +
- " c = a + b + five\n"+
- " return {'c':c}\n\n")
- .returnAllInputs(true)
- .setupAndRun(true)
- .build();
-
- List> inputs = new ArrayList<>();
- inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)),
- new NDArrayWritable(Nd4j.scalar(2).reshape(1,1))));
- Schema inputSchema = new Builder()
- .addColumnNDArray("a",new long[]{1,1})
- .addColumnNDArray("b", new long[]{1, 1})
- .build();
-
- TransformProcess tp = new TransformProcess.Builder(inputSchema)
- .transform(pythonTransform)
- .build();
- List> execute = LocalTransformExecutor.execute(inputs, tp);
- assertFalse(execute.isEmpty());
- assertNotNull(execute.get(0));
- assertNotNull(execute.get(0).get(0));
- assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get());
- }
-
-}
\ No newline at end of file
diff --git a/datavec/datavec-spark/pom.xml b/datavec/datavec-spark/pom.xml
index 27648bdfe..98d65b390 100644
--- a/datavec/datavec-spark/pom.xml
+++ b/datavec/datavec-spark/pom.xml
@@ -128,10 +128,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/datavec/pom.xml b/datavec/pom.xml
index 6c4d9496a..d307284b1 100644
--- a/datavec/pom.xml
+++ b/datavec/pom.xml
@@ -92,6 +92,10 @@
org.junit.jupiter
junit-jupiter-api
+
+ org.junit.jupiter
+ junit-jupiter-params
+
org.junit.vintage
junit-vintage-engine
@@ -154,7 +158,7 @@
${skipTestResourceEnforcement}
- test-nd4j-native,test-nd4j-cuda-11.0
+ nd4j-tests-cpu,nd4j-tests-cuda
false
@@ -163,23 +167,6 @@
-
- maven-surefire-plugin
-
-
-
-
- true
- false
-
-
org.eclipse.m2e
lifecycle-mapping
@@ -249,7 +236,7 @@
- test-nd4j-native
+ nd4j-tests-cpu
org.nd4j
@@ -266,7 +253,7 @@
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
org.nd4j
@@ -286,9 +273,6 @@
org.apache.maven.plugins
maven-surefire-plugin
-
- -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"
-
diff --git a/deeplearning4j/deeplearning4j-common-tests/pom.xml b/deeplearning4j/deeplearning4j-common-tests/pom.xml
index cce6ea55d..7e1f27e15 100644
--- a/deeplearning4j/deeplearning4j-common-tests/pom.xml
+++ b/deeplearning4j/deeplearning4j-common-tests/pom.xml
@@ -64,7 +64,7 @@
- test-nd4j-native
+ nd4j-tests-cpu
org.nd4j
@@ -75,7 +75,7 @@
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
org.nd4j
diff --git a/deeplearning4j/deeplearning4j-common/pom.xml b/deeplearning4j/deeplearning4j-common/pom.xml
index c63939b27..e2be6465f 100644
--- a/deeplearning4j/deeplearning4j-common/pom.xml
+++ b/deeplearning4j/deeplearning4j-common/pom.xml
@@ -56,10 +56,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml
index 655e60a8a..4fd587d9c 100644
--- a/deeplearning4j/deeplearning4j-core/pom.xml
+++ b/deeplearning4j/deeplearning4j-core/pom.xml
@@ -166,7 +166,7 @@
- test-nd4j-native
+ nd4j-tests-cpu
org.nd4j
@@ -177,7 +177,7 @@
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
org.nd4j
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java
index 0fe9528b8..0627eacf2 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java
@@ -23,7 +23,6 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.jupiter.api.Test;
-import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@@ -34,7 +33,6 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
-import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Early Termination Data Set Iterator Test")
class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java
index 6a953278b..929f802ff 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java
@@ -21,19 +21,16 @@ package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
-
+import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
-import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
+
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
-import org.junit.jupiter.api.DisplayName;
-import org.junit.jupiter.api.extension.ExtendWith;
-
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Early Termination Multi Data Set Iterator Test")
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java
index 023f35449..4dee21804 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java
@@ -34,7 +34,6 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
-import org.junit.rules.ExpectedException;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -46,7 +45,6 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
-import org.junit.jupiter.api.extension.ExtendWith;
@Disabled
@DisplayName("Attention Layer Test")
diff --git a/deeplearning4j/deeplearning4j-cuda/pom.xml b/deeplearning4j/deeplearning4j-cuda/pom.xml
index 3c12fbbc3..1555915d7 100644
--- a/deeplearning4j/deeplearning4j-cuda/pom.xml
+++ b/deeplearning4j/deeplearning4j-cuda/pom.xml
@@ -105,11 +105,12 @@
- test-nd4j-native
+ nd4j-tests-cpu
maven-surefire-plugin
+ true
true
@@ -118,7 +119,7 @@
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/pom.xml b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/pom.xml
index 791cd923a..45ee5100b 100644
--- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/pom.xml
+++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/pom.xml
@@ -56,10 +56,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/pom.xml b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/pom.xml
index 048d62fd0..748a10c50 100644
--- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/pom.xml
+++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/pom.xml
@@ -50,10 +50,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml
index 5e8d6561c..10ce9a8ce 100644
--- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml
+++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml
@@ -45,10 +45,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-data/pom.xml b/deeplearning4j/deeplearning4j-data/pom.xml
index 6792e9d38..5f047041b 100644
--- a/deeplearning4j/deeplearning4j-data/pom.xml
+++ b/deeplearning4j/deeplearning4j-data/pom.xml
@@ -54,10 +54,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml b/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml
index fc5cee1ac..cce784580 100644
--- a/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml
+++ b/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml
@@ -112,7 +112,7 @@
- test-nd4j-native
+ nd4j-tests-cpu
org.nd4j
@@ -123,7 +123,7 @@
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
org.nd4j
diff --git a/deeplearning4j/deeplearning4j-graph/pom.xml b/deeplearning4j/deeplearning4j-graph/pom.xml
index 164219a58..8ae897976 100644
--- a/deeplearning4j/deeplearning4j-graph/pom.xml
+++ b/deeplearning4j/deeplearning4j-graph/pom.xml
@@ -72,10 +72,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml
index 3f430ab04..3ff0353b3 100644
--- a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml
+++ b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml
@@ -306,7 +306,7 @@
- test-nd4j-native
+ nd4j-tests-cpu
org.nd4j
@@ -317,7 +317,7 @@
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
org.nd4j
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml
index a4ea94d8b..dcadbfa19 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml
@@ -101,10 +101,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/pom.xml
index 7c7773d6f..e1f0c35c8 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/pom.xml
+++ b/deeplearning4j/deeplearning4j-nlp-parent/pom.xml
@@ -49,10 +49,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-nn/pom.xml b/deeplearning4j/deeplearning4j-nn/pom.xml
index 62d092567..6ebce95d6 100644
--- a/deeplearning4j/deeplearning4j-nn/pom.xml
+++ b/deeplearning4j/deeplearning4j-nn/pom.xml
@@ -127,10 +127,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml
index 994364216..ed9625547 100644
--- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml
@@ -102,7 +102,7 @@
- test-nd4j-native
+ nd4j-tests-cpu
org.nd4j
@@ -113,7 +113,7 @@
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
org.nd4j
diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml
index 77e481c6a..09e9603c6 100644
--- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml
@@ -99,10 +99,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-scaleout/pom.xml b/deeplearning4j/deeplearning4j-scaleout/pom.xml
index 6cb37caa7..30758ee79 100644
--- a/deeplearning4j/deeplearning4j-scaleout/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/pom.xml
@@ -44,10 +44,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml
index 850335cbf..431ffe764 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml
@@ -89,10 +89,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml
index 9e6f92e6b..ba96a4b88 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml
@@ -88,10 +88,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml
index 4136e2a92..e60be88d2 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml
@@ -90,10 +90,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml
index 1068bda5c..7a328ca52 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml
@@ -105,10 +105,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml
index c74e3e94e..0147f87af 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml
@@ -182,10 +182,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml
index 3a96e8a4a..e5b5254d0 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml
@@ -77,10 +77,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml
index 137d78fce..040011ab8 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml
@@ -104,10 +104,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml
index aa75528fe..b02387920 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml
@@ -141,10 +141,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml
index 53d11e05a..aa0271686 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml
@@ -79,10 +79,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml
index b7924d582..a9df8ea56 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml
@@ -426,10 +426,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
\ No newline at end of file
diff --git a/deeplearning4j/deeplearning4j-ui-parent/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/pom.xml
index a48f7c43d..db3833dd6 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/pom.xml
@@ -44,10 +44,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/deeplearning4j-zoo/pom.xml b/deeplearning4j/deeplearning4j-zoo/pom.xml
index b93606710..e1508e08c 100644
--- a/deeplearning4j/deeplearning4j-zoo/pom.xml
+++ b/deeplearning4j/deeplearning4j-zoo/pom.xml
@@ -87,10 +87,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/deeplearning4j/dl4j-integration-tests/pom.xml b/deeplearning4j/dl4j-integration-tests/pom.xml
index 461d013a7..a491f38a7 100644
--- a/deeplearning4j/dl4j-integration-tests/pom.xml
+++ b/deeplearning4j/dl4j-integration-tests/pom.xml
@@ -117,10 +117,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
\ No newline at end of file
diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml
index 475b84d15..1212df5d6 100644
--- a/deeplearning4j/pom.xml
+++ b/deeplearning4j/pom.xml
@@ -143,6 +143,10 @@
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+
org.apache.maven.plugins
maven-enforcer-plugin
@@ -158,7 +162,7 @@
${skipBackendChoice}
- test-nd4j-native,test-nd4j-cuda-11.0
+ nd4j-tests-cpu,nd4j-tests-cuda
false
@@ -227,43 +231,6 @@
-
-
-
- maven-surefire-plugin
- true
-
-
- true
- false
- -Dfile.encoding=UTF-8 -Xmx8g "
-
-
- *.java
- **/*.java
-
-
-
-
- org.apache.maven.surefire
- surefire-junit-platform
- ${maven-surefire-plugin.version}
-
-
-
-
- org.eclipse.m2e
- lifecycle-mapping
-
-
-
@@ -290,10 +257,10 @@
deeplearning4j-cuda
-
- test-nd4j-native
+ nd4j-tests-cpu
false
@@ -311,70 +278,10 @@
test
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
- true
-
-
- org.nd4j
- nd4j-native
- ${project.version}
-
-
- org.junit.jupiter
- junit-jupiter-engine
- ${junit.version}
-
-
- org.junit.jupiter
- junit-jupiter-params
- ${junit.version}
-
-
- org.apache.maven.surefire
- surefire-junit-platform
- ${maven-surefire-plugin.version}
-
-
-
-
-
-
- src/test/java
-
- *.java
- **/*.java
- **/Test*.java
- **/*Test.java
- **/*TestCase.java
-
- org.junit.jupiter:junit-jupiter-engine
-
-
- org.nd4j.linalg.cpu.nativecpu.CpuBackend
-
-
- org.nd4j.linalg.cpu.nativecpu.CpuBackend
-
-
-
-
-
-
-
-
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
false
@@ -392,43 +299,6 @@
test
-
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
- ${maven-surefire-plugin.version}
-
-
-
- src/test/java
-
- *.java
- **/*.java
- **/Test*.java
- **/*Test.java
- **/*TestCase.java
-
- org.junit.jupiter:junit-jupiter
-
-
- org.nd4j.linalg.jcublas.JCublasBackend
-
-
- org.nd4j.linalg.jcublas.JCublasBackend
-
-
-
- -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"
-
-
-
-
-
diff --git a/libnd4j/test-results.txt b/libnd4j/test-results.txt
index aee60b267..84816b6f5 100644
--- a/libnd4j/test-results.txt
+++ b/libnd4j/test-results.txt
@@ -5,7 +5,7 @@ Linux
[INFO] Total time: 14.610 s
[INFO] Finished at: 2021-03-06T15:35:28+09:00
[INFO] ------------------------------------------------------------------------
-[WARNING] The requested profile "test-nd4j-native" could not be activated because it does not exist.
+[WARNING] The requested profile "nd4j-tests-cpu" could not be activated because it does not exist.
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
[ERROR]
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
@@ -749,7 +749,7 @@ make[1]: Leaving directory '/c/Users/agibs/Documents/GitHub/eclipse-deeplearning
[INFO] Total time: 15.482 s
[INFO] Finished at: 2021-03-06T15:27:35+09:00
[INFO] ------------------------------------------------------------------------
-[WARNING] The requested profile "test-nd4j-native" could not be activated because it does not exist.
+[WARNING] The requested profile "nd4j-tests-cpu" could not be activated because it does not exist.
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
[ERROR]
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
diff --git a/nd4j/nd4j-backends/nd4j-tests/ops-added-old.txt b/nd4j/nd4j-backends/nd4j-tests/ops-added-old.txt
new file mode 100644
index 000000000..84cf4d764
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/ops-added-old.txt
@@ -0,0 +1,19 @@
+Const,in_0
+Const,while/Const
+Const,while/add/y
+Identity,in_0/read
+Enter,while/Enter
+Enter,while/Enter_1
+Merge,while/Merge
+Merge,while/Merge_1
+Less,while/Less
+LoopCond,while/LoopCond
+Switch,while/Switch
+Switch,while/Switch_1
+Identity,while/Identity
+Exit,while/Exit
+Identity,while/Identity_1
+Exit,while/Exit_1
+Add,while/add
+NextIteration,while/NextIteration_1
+NextIteration,while/NextIteration
diff --git a/nd4j/nd4j-backends/nd4j-tests/ops-imported-old.txt b/nd4j/nd4j-backends/nd4j-tests/ops-imported-old.txt
new file mode 100644
index 000000000..f4bde2724
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/ops-imported-old.txt
@@ -0,0 +1,16 @@
+Identity,in_0/read
+Enter,while/Enter
+Enter,while/Enter_1
+Merge,while/Merge
+Merge,while/Merge_1
+Less,while/Less
+LoopCond,while/LoopCond
+Switch,while/Switch
+Switch,while/Switch_1
+Identity,while/Identity
+Exit,while/Exit
+Identity,while/Identity_1
+Exit,while/Exit_1
+Add,while/add
+NextIteration,while/NextIteration_1
+NextIteration,while/NextIteration
diff --git a/nd4j/nd4j-backends/nd4j-tests/ops-removed-old.txt b/nd4j/nd4j-backends/nd4j-tests/ops-removed-old.txt
new file mode 100644
index 000000000..201dc67b4
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/ops-removed-old.txt
@@ -0,0 +1,19 @@
+in_0
+while/Const
+while/add/y
+in_0/read
+while/Enter
+while/Enter_1
+while/Merge
+while/Merge_1
+while/Less
+while/LoopCond
+while/Switch
+while/Switch_1
+while/Identity
+while/Exit
+while/Identity_1
+while/Exit_1
+while/add
+while/NextIteration_1
+while/NextIteration
diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml
index 0d55475e6..d70eb3ced 100644
--- a/nd4j/nd4j-backends/nd4j-tests/pom.xml
+++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml
@@ -303,7 +303,7 @@
For testing large zoo models, this may not be enough (so comment it out).
-->
- -Dfile.encoding=UTF-8 "
+ -Dfile.encoding=UTF-8
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
index ea931b3a3..d11881051 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
@@ -27,6 +27,7 @@ import java.util.List;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable;
@@ -482,7 +483,7 @@ public class LayerOpValidation extends BaseOpValidation {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testConv3d(Nd4jBackend backend) {
+ public void testConv3d(Nd4jBackend backend, TestInfo testInfo) {
//Pooling3d, Conv3D, batch norm
Nd4j.getRandom().setSeed(12345);
@@ -573,7 +574,7 @@ public class LayerOpValidation extends BaseOpValidation {
tc.testName(msg);
String error = OpValidation.validate(tc);
if (error != null) {
- failed.add(name);
+ failed.add(testInfo.getTestMethod().get().getName());
}
}
}
@@ -1353,7 +1354,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err, err);
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
int nIn = 3;
@@ -1382,7 +1384,8 @@ public class LayerOpValidation extends BaseOpValidation {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345);
@@ -1405,7 +1408,8 @@ public class LayerOpValidation extends BaseOpValidation {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java
index 2654caf02..591898055 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java
@@ -22,6 +22,7 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
+import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
@@ -664,6 +665,7 @@ public class MiscOpValidation extends BaseOpValidation {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ @Disabled
public void testMmulGradientManual(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java
index edf5859fa..3681c77b8 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java
@@ -69,7 +69,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@AfterEach
- public void tearDown(Nd4jBackend backend) {
+ public void tearDown() {
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java
index a015bfec7..b7e3a6551 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java
@@ -28,6 +28,7 @@ import lombok.val;
import org.apache.commons.math3.linear.LUDecomposition;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
@@ -83,7 +84,7 @@ public class ShapeOpValidation extends BaseOpValidation {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testConcat(Nd4jBackend backend) {
+ public void testConcat(Nd4jBackend backend, TestInfo testInfo) {
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
int[] concatDim = new int[]{0, 0, 0};
List> origShapes = new ArrayList<>();
@@ -115,7 +116,7 @@ public class ShapeOpValidation extends BaseOpValidation {
String error = OpValidation.validate(tc);
if(error != null){
- failed.add(name);
+ failed.add(testInfo.getTestMethod().get().getName());
}
}
@@ -285,7 +286,7 @@ public class ShapeOpValidation extends BaseOpValidation {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testSqueezeGradient(Nd4jBackend backend) {
+ public void testSqueezeGradient(Nd4jBackend backend,TestInfo testInfo) {
val origShape = new long[]{3, 4, 5};
List failed = new ArrayList<>();
@@ -339,7 +340,7 @@ public class ShapeOpValidation extends BaseOpValidation {
String error = OpValidation.validate(tc, true);
if(error != null){
- failed.add(name);
+ failed.add(testInfo.getTestMethod().get().getName());
}
}
}
@@ -580,8 +581,9 @@ public class ShapeOpValidation extends BaseOpValidation {
return Long.MAX_VALUE;
}
- @Test()
- public void testStack(Nd4jBackend backend) {
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ public void testStack(Nd4jBackend backend,TestInfo testInfo) {
Nd4j.getRandom().setSeed(12345);
List failed = new ArrayList<>();
@@ -661,7 +663,7 @@ public class ShapeOpValidation extends BaseOpValidation {
String error = OpValidation.validate(tc);
if(error != null){
- failed.add(name);
+ failed.add(testInfo.getTestMethod().get().getName());
}
}
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java
index 4ff306796..42f93b98e 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java
@@ -72,6 +72,8 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
+ @TempDir Path testDir;
+
@Override
public char ordering(){
@@ -82,7 +84,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void testBasic(Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create();
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() );
@@ -121,7 +123,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
int numOutputs = fg.outputsLength();
List outputs = new ArrayList<>(numOutputs);
- for( int i=0; i expTPR = new HashMap<>();
double totalPositives = 2.0;
@@ -251,27 +252,27 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
//Same as first test...
//2 outputs here - probability distribution over classes (softmax)
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
- {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
- {0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
+ {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
+ {0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
- {0, 1}, {0, 1}});
+ {0, 1}, {0, 1}});
INDArray predictions3d = Nd4j.create(2, 2, 5);
INDArray firstTSp =
- predictions3d.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
+ predictions3d.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
assertArrayEquals(new long[] {5, 2}, firstTSp.shape());
firstTSp.assign(predictions2d.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all()));
INDArray secondTSp =
- predictions3d.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
+ predictions3d.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
assertArrayEquals(new long[] {5, 2}, secondTSp.shape());
secondTSp.assign(predictions2d.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all()));
@@ -299,23 +300,23 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocTimeSeriesMasking(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax)
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
- {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
- {0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
+ {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
+ {0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
- {0, 1}, {0, 1}});
+ {0, 1}, {0, 1}});
//Create time series data... first time series: length 4. Second time series: length 6
INDArray predictions3d = Nd4j.create(2, 2, 6);
INDArray tad = predictions3d.tensorAlongDimension(0, 1, 2).transpose();
tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
- .assign(predictions2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
+ .assign(predictions2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
tad = predictions3d.tensorAlongDimension(1, 1, 2).transpose();
tad.assign(predictions2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));
@@ -324,7 +325,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
INDArray labels3d = Nd4j.create(2, 2, 6);
tad = labels3d.tensorAlongDimension(0, 1, 2).transpose();
tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
- .assign(actual2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
+ .assign(actual2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
tad = labels3d.tensorAlongDimension(1, 1, 2).transpose();
tad.assign(actual2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));
@@ -350,7 +351,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
@@ -381,7 +382,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCompare2Vs3Classes(Nd4jBackend backend) {
@@ -431,7 +432,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCMerging(Nd4jBackend backend) {
int nArrays = 10;
@@ -477,7 +478,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCMerging2(Nd4jBackend backend) {
int nArrays = 10;
@@ -523,7 +524,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCMultiMerging(Nd4jBackend backend) {
@@ -572,7 +573,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAUCPrecisionRecall(Nd4jBackend backend) {
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
@@ -620,7 +621,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocAucExact(Nd4jBackend backend) {
@@ -681,20 +682,20 @@ public class ROCTest extends BaseNd4jTestWithBackends {
*/
double[] p = new double[] {0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503, 0.5955447, 0.96451452,
- 0.6531771, 0.74890664, 0.65356987, 0.74771481, 0.96130674, 0.0083883, 0.10644438, 0.29870371,
- 0.65641118, 0.80981255, 0.87217591, 0.9646476, 0.72368535, 0.64247533, 0.71745362, 0.46759901,
- 0.32558468, 0.43964461, 0.72968908, 0.99401459, 0.67687371, 0.79082252, 0.17091426};
+ 0.6531771, 0.74890664, 0.65356987, 0.74771481, 0.96130674, 0.0083883, 0.10644438, 0.29870371,
+ 0.65641118, 0.80981255, 0.87217591, 0.9646476, 0.72368535, 0.64247533, 0.71745362, 0.46759901,
+ 0.32558468, 0.43964461, 0.72968908, 0.99401459, 0.67687371, 0.79082252, 0.17091426};
double[] l = new double[] {1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
- 0, 1};
+ 0, 1};
double[] fpr_skl = new double[] {0.0, 0.0, 0.15789474, 0.15789474, 0.31578947, 0.31578947, 0.52631579,
- 0.52631579, 0.68421053, 0.68421053, 0.84210526, 0.84210526, 0.89473684, 0.89473684, 1.0};
+ 0.52631579, 0.68421053, 0.68421053, 0.84210526, 0.84210526, 0.89473684, 0.89473684, 1.0};
double[] tpr_skl = new double[] {0.0, 0.09090909, 0.09090909, 0.18181818, 0.18181818, 0.36363636, 0.36363636,
- 0.45454545, 0.45454545, 0.72727273, 0.72727273, 0.90909091, 0.90909091, 1.0, 1.0};
+ 0.45454545, 0.45454545, 0.72727273, 0.72727273, 0.90909091, 0.90909091, 1.0, 1.0};
//Note the change to the last value: same TPR and FPR at 0.0083883 and 0.0 -> we add the 0.0 threshold edge case + combine with the previous one. Same result
double[] thr_skl = new double[] {1.0, 0.99401459, 0.96130674, 0.92961609, 0.79082252, 0.74771481, 0.67687371,
- 0.65641118, 0.64247533, 0.46759901, 0.31637555, 0.20456028, 0.18391881, 0.17091426, 0.0};
+ 0.65641118, 0.64247533, 0.46759901, 0.31637555, 0.20456028, 0.18391881, 0.17091426, 0.0};
INDArray prob = Nd4j.create(p, new int[] {30, 1});
INDArray label = Nd4j.create(l, new int[] {30, 1});
@@ -784,7 +785,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
@@ -797,7 +798,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
double[] threshold = new double[101];
@@ -814,15 +815,15 @@ public class ROCTest extends BaseNd4jTestWithBackends {
PrecisionRecallCurve prc = new PrecisionRecallCurve(threshold, precision, recall, null, null, null, -1);
PrecisionRecallCurve.Point[] points = new PrecisionRecallCurve.Point[] {
- //Test exact:
- prc.getPointAtThreshold(0.05), prc.getPointAtPrecision(0.05), prc.getPointAtRecall(1 - 0.05),
+ //Test exact:
+ prc.getPointAtThreshold(0.05), prc.getPointAtPrecision(0.05), prc.getPointAtRecall(1 - 0.05),
- //Test approximate (point doesn't exist exactly). When it doesn't exist:
- //Threshold: lowest threshold equal to or exceeding the specified threshold value
- //Precision: lowest threshold equal to or exceeding the specified precision value
- //Recall: highest threshold equal to or exceeding the specified recall value
- prc.getPointAtThreshold(0.0495), prc.getPointAtPrecision(0.0495),
- prc.getPointAtRecall(1 - 0.0505)};
+ //Test approximate (point doesn't exist exactly). When it doesn't exist:
+ //Threshold: lowest threshold equal to or exceeding the specified threshold value
+ //Precision: lowest threshold equal to or exceeding the specified precision value
+ //Recall: highest threshold equal to or exceeding the specified recall value
+ prc.getPointAtThreshold(0.0495), prc.getPointAtPrecision(0.0495),
+ prc.getPointAtRecall(1 - 0.0505)};
@@ -834,7 +835,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
//Sanity check: values calculated from the confusion matrix should match the PR curve values
@@ -843,7 +844,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
ROC r = new ROC(0, removeRedundantPts);
INDArray labels = Nd4j.getExecutioner()
- .exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE,100, 1), 0.5));
+ .exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE,100, 1), 0.5));
INDArray probs = Nd4j.rand(100, 1);
r.eval(labels, probs);
@@ -874,7 +875,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocMerge(){
Nd4j.getRandom().setSeed(12345);
@@ -919,7 +920,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
assertEquals(auprc, auprcAct, 1e-6);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocMultiMerge(){
Nd4j.getRandom().setSeed(12345);
@@ -931,9 +932,9 @@ public class ROCTest extends BaseNd4jTestWithBackends {
int nOut = 5;
Random r = new Random(12345);
- for( int i=0; i<10; i++ ){
+ for( int i = 0; i < 10; i++ ){
INDArray labels = Nd4j.zeros(3, nOut);
- for( int j=0; j<3; j++ ){
+ for( int j = 0; j < 3; j++) {
labels.putScalar(j, r.nextInt(nOut), 1.0 );
}
INDArray out = Nd4j.rand(3, nOut);
@@ -956,7 +957,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
roc1.merge(roc2);
- for( int i=0; i {
int specCols = 5;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java
index e49e91937..aeb8a4705 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java
@@ -152,7 +152,7 @@ public class LoneTest extends BaseNd4jTestWithBackends {
public void maskWhenMerge(Nd4jBackend backend) {
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
- List dataSetList = new ArrayList();
+ List dataSetList = new ArrayList<>();
dataSetList.add(dsA);
dataSetList.add(dsB);
DataSet fullDataSet = DataSet.merge(dataSetList);
@@ -175,7 +175,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
// System.out.println(b);
}
- @Test
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
//broken at a threshold
public void testArgMax(Nd4jBackend backend) {
int max = 63;
@@ -263,7 +264,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
// log.info("p50: {}; avg: {};", times.get(times.size() / 2), time);
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void checkIllegalElementOps(Nd4jBackend backend) {
assertThrows(Exception.class,() -> {
INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5);
@@ -328,13 +330,13 @@ public class LoneTest extends BaseNd4jTestWithBackends {
reshaped.getDouble(i);
}
for (int j=0;j {
INDArray arr = Nd4j.create(4, 5);
@@ -2357,7 +2361,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
}
- @Test
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testTensorDot(Nd4jBackend backend) {
INDArray oneThroughSixty = Nd4j.arange(60).reshape(3, 4, 5).castTo(DataType.DOUBLE);
@@ -3051,10 +3056,10 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
public void testMeans(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray mean1 = a.mean(1);
- assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage());
- assertEquals(Nd4j.create(new double[] {2, 3}), a.mean(0),getFailureMessage());
- assertEquals(2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1,getFailureMessage());
- assertEquals(2.5, a.meanNumber().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage(backend));
+ assertEquals(Nd4j.create(new double[] {2, 3}), a.mean(0),getFailureMessage(backend));
+ assertEquals(2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1,getFailureMessage(backend));
+ assertEquals(2.5, a.meanNumber().doubleValue(), 1e-1,getFailureMessage(backend));
}
@@ -3063,9 +3068,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSums(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
- assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage());
- assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage());
- assertEquals(10, a.sumNumber().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage(backend));
+ assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage(backend));
+ assertEquals(10, a.sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
}
@@ -3438,7 +3443,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
}
- @Test
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void largeInstantiation(Nd4jBackend backend) {
Nd4j.ones((1024 * 1024 * 511) + (1024 * 1024 - 1)); // Still works; this can even be called as often as I want, allowing me even to spill over on disk
@@ -3487,7 +3493,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(cSum, fSum); //Expect: 4,6. Getting [4, 4] for f order
}
- @Test
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled //not relevant anymore
public void testAssignMixedC(Nd4jBackend backend) {
int[] shape1 = {3, 2, 2, 2, 2, 2};
@@ -3787,7 +3794,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, result);
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation1(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(10, 10), 2, new int[] {0, 1, 2});
@@ -3795,7 +3803,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
});
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation2(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, -1, 2});
@@ -3803,7 +3812,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
});
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation3(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, 1, 10});
@@ -3811,7 +3821,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
});
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation4(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2, 3});
@@ -3819,7 +3830,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
});
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation5(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2}, 'e');
@@ -4975,7 +4987,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTadReduce3_5(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
INDArray initial = Nd4j.create(5, 10);
@@ -6004,7 +6017,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
}
- @Test
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testLogExpSum1(Nd4jBackend backend) {
INDArray matrix = Nd4j.create(3, 3);
@@ -6019,7 +6033,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
}
- @Test
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testLogExpSum2(Nd4jBackend backend) {
INDArray row = Nd4j.create(new double[]{1, 2, 3});
@@ -6246,7 +6261,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReshapeFailure(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
val a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2);
@@ -6345,7 +6361,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertArrayEquals(new long[]{3, 2}, newShape.shape());
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTranspose1(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
@@ -6360,7 +6377,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTranspose2(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
val scalar = Nd4j.scalar(2.f);
@@ -6375,7 +6393,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
- @Test
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
//@Disabled
public void testMatmul_128by256(Nd4jBackend backend) {
val mA = Nd4j.create(128, 156).assign(1.0f);
@@ -6647,7 +6666,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(exp1, out1);
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBadReduce3Call(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
val x = Nd4j.create(400,20);
@@ -7392,8 +7412,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(ez, z);
}
- @Test()
- public void testBroadcastInvalid(){
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ public void testBroadcastInvalid() {
assertThrows(IllegalStateException.class,() -> {
INDArray arr1 = Nd4j.ones(3,4,1);
@@ -7656,7 +7677,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(exp, array);
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScatterUpdateShortcut_f1(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
val array = Nd4j.create(DataType.FLOAT, 5, 2);
@@ -8041,7 +8063,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(exp, out); //Failing here
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsFailure(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
val idxs = new int[]{0,2,3,4};
@@ -8144,7 +8167,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(exp1, out1); //This is OK
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPutRowValidation(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
val matrix = Nd4j.create(5, 10);
@@ -8155,7 +8179,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPutColumnValidation(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
val matrix = Nd4j.create(5, 10);
@@ -8236,7 +8261,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testScalarEq(){
+ public void testScalarEq(Nd4jBackend backend){
INDArray scalarRank2 = Nd4j.scalar(10.0).reshape(1,1);
INDArray scalarRank1 = Nd4j.scalar(10.0).reshape(1);
INDArray scalarRank0 = Nd4j.scalar(10.0);
@@ -8273,7 +8298,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testType1(@TempDir Path testDir) throws IOException {
+ @Disabled
+ public void testType1(Nd4jBackend backend) throws IOException {
for (int i = 0; i < 10; ++i) {
INDArray in1 = Nd4j.rand(DataType.DOUBLE, new int[]{100, 100});
File dir = testDir.toFile();
@@ -8295,7 +8321,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testOnes(){
+ public void testOnes(Nd4jBackend backend){
INDArray arr = Nd4j.ones();
INDArray arr2 = Nd4j.ones(DataType.LONG);
assertEquals(0, arr.rank());
@@ -8306,7 +8332,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testZeros(){
+ public void testZeros(Nd4jBackend backend){
INDArray arr = Nd4j.zeros();
INDArray arr2 = Nd4j.zeros(DataType.LONG);
assertEquals(0, arr.rank());
@@ -8317,7 +8343,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testType2(@TempDir Path testDir) throws IOException {
+ @Disabled
+ public void testType2(Nd4jBackend backend) throws IOException {
for (int i = 0; i < 10; ++i) {
INDArray in1 = Nd4j.ones(DataType.UINT16);
File dir = testDir.toFile();
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java
index b44170433..03eacb890 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ToStringTest.java
@@ -23,6 +23,7 @@ package org.nd4j.linalg;
import static org.junit.jupiter.api.Assertions.assertEquals;
import lombok.extern.slf4j.Slf4j;
+import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
@@ -58,11 +59,12 @@ public class ToStringTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testToStringScalars(){
+ @Disabled
+ public void testToStringScalars(Nd4jBackend backend){
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};
- for(int dt=0; dt<5; dt++ ) {
+ for(int dt = 0; dt < 5; dt++) {
for (int i = 0; i < 5; i++) {
long[] shape = ArrayUtil.nTimes(i, 1L);
INDArray scalar = Nd4j.scalar(1.0f).castTo(dataTypes[dt]).reshape(shape);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreation.java
index c0a387ad3..2d7a56eae 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNDArrayCreation.java
@@ -64,7 +64,6 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
}
- @Test
@Disabled
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@@ -79,7 +78,6 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
}
- @Test
@Disabled
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@@ -100,7 +98,8 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
}
- @Test
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCreateNpy3(Nd4jBackend backend) throws Exception {
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
assertEquals(8, arrCreate.length());
@@ -111,8 +110,9 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
assertEquals(arrCreate.data().address(), pointer.address());
}
- @Test
@Disabled // this is endless test
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEndlessAllocation(Nd4jBackend backend) {
Nd4j.getEnvironment().setMaxSpecialMemory(1);
while (true) {
@@ -121,9 +121,10 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
}
}
- @Test
@Disabled("This test is designed to run in isolation. With parallel gc it makes no real sense since allocated amount changes at any time")
- public void testAllocationLimits() throws Exception {
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ public void testAllocationLimits(Nd4jBackend backend) throws Exception {
Nd4j.create(1);
val origDeviceLimit = Nd4j.getEnvironment().getDeviceLimit(0);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java
index 258177261..39eb7dfd0 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/TestNamespaces.java
@@ -20,7 +20,6 @@
package org.nd4j.linalg.api;
-import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level1Test.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level1Test.java
index 1584b72dc..5ececc0d8 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level1Test.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/blas/Level1Test.java
@@ -59,7 +59,7 @@ public class Level1Test extends BaseNd4jTestWithBackends {
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray row = matrix.getRow(1);
Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row);
- assertEquals(Nd4j.create(new double[] {4, 8}), row,getFailureMessage());
+ assertEquals(Nd4j.create(new double[] {4, 8}), row,getFailureMessage(backend));
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java
index 3e7971eed..5735852b5 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/DataTypeValidationTests.java
@@ -70,8 +70,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
/**
* Testing level1 blas
*/
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBlasValidation1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
@@ -89,8 +88,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
/**
* Testing level2 blas
*/
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBlasValidation2(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> {
@@ -109,8 +107,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
/**
* Testing level3 blas
*/
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBlasValidation3(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java
index 5f4fd3665..8af34a323 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/FloatDataBufferTest.java
@@ -88,7 +88,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
float[] d2 = d.asFloat();
- assertArrayEquals( d1, d2, 1e-1f,getFailureMessage());
+ assertArrayEquals( d1, d2, 1e-1f,getFailureMessage(backend));
}
@@ -146,7 +146,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
d.put(0, 0.0);
float[] result = new float[] {0, 2, 3, 4};
d1 = d.asFloat();
- assertArrayEquals(d1, result, 1e-1f,getFailureMessage());
+ assertArrayEquals(d1, result, 1e-1f,getFailureMessage(backend));
}
@@ -156,12 +156,12 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(0, 3);
float[] data = new float[] {1, 2, 3};
- assertArrayEquals(get, data, 1e-1f,getFailureMessage());
+ assertArrayEquals(get, data, 1e-1f,getFailureMessage(backend));
float[] get2 = buffer.asFloat();
float[] allData = buffer.getFloatsAt(0, (int) buffer.length());
- assertArrayEquals(get2, allData, 1e-1f,getFailureMessage());
+ assertArrayEquals(get2, allData, 1e-1f,getFailureMessage(backend));
}
@@ -173,13 +173,13 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(1, 3);
float[] data = new float[] {2, 3, 4};
- assertArrayEquals(get, data, 1e-1f,getFailureMessage());
+ assertArrayEquals(get, data, 1e-1f,getFailureMessage(backend));
float[] allButLast = new float[] {2, 3, 4, 5};
float[] allData = buffer.getFloatsAt(1, (int) buffer.length());
- assertArrayEquals(allButLast, allData, 1e-1f,getFailureMessage());
+ assertArrayEquals(allButLast, allData, 1e-1f,getFailureMessage(backend));
}
@@ -190,7 +190,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
public void testAsBytes(Nd4jBackend backend) {
INDArray arr = Nd4j.create(5);
byte[] d = arr.data().asBytes();
- assertEquals(4 * 5, d.length,getFailureMessage());
+ assertEquals(4 * 5, d.length,getFailureMessage(backend));
INDArray rand = Nd4j.rand(3, 3);
rand.data().asBytes();
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java
index 45ef02238..fbcbd656a 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java
@@ -20,26 +20,18 @@
package org.nd4j.linalg.api.indexing;
-import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
-
import org.nd4j.common.base.Preconditions;
+import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
-import org.nd4j.linalg.indexing.INDArrayIndex;
-import org.nd4j.linalg.indexing.IntervalIndex;
-import org.nd4j.linalg.indexing.NDArrayIndex;
-import org.nd4j.linalg.indexing.NDArrayIndexAll;
-import org.nd4j.linalg.indexing.NewAxis;
-import org.nd4j.linalg.indexing.PointIndex;
-import org.nd4j.linalg.indexing.SpecifiedIndex;
+import org.nd4j.linalg.indexing.*;
import org.nd4j.linalg.ops.transforms.Transforms;
-import org.nd4j.common.util.ArrayUtil;
import java.util.Arrays;
import java.util.Random;
@@ -56,22 +48,22 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testNegativeBounds() {
- INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
- INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
- INDArray get = arr.get(NDArrayIndex.all(),interval);
- INDArray assertion = Nd4j.create(new double[][]{
- {1,2,3},
- {6,7,8}
- });
- assertEquals(assertion,get);
+ public void testNegativeBounds(Nd4jBackend backend) {
+ INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
+ INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
+ INDArray get = arr.get(NDArrayIndex.all(),interval);
+ INDArray assertion = Nd4j.create(new double[][]{
+ {1,2,3},
+ {6,7,8}
+ });
+ assertEquals(assertion,get);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testNewAxis() {
+ public void testNewAxis(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2);
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all());
long[] shapeAssertion = {3, 2, 1, 1, 2};
@@ -79,9 +71,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void broadcastBug() {
+ public void broadcastBug(Nd4jBackend backend) {
INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2});
final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0));
@@ -91,9 +83,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testIntervalsIn3D() {
+ public void testIntervalsIn3D(Nd4jBackend backend) {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
INDArray rest = arr.get(interval(1, 2), interval(0, 2), interval(0, 2));
@@ -101,9 +93,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testSmallInterval() {
+ public void testSmallInterval(Nd4jBackend backend) {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
INDArray rest = arr.get(interval(1, 2), all(), all());
@@ -111,9 +103,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testAllWithNewAxisAndInterval() {
+ public void testAllWithNewAxisAndInterval(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3);
@@ -121,9 +113,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion2, get2);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testAllWithNewAxisInMiddle() {
+ public void testAllWithNewAxisInMiddle(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3);
@@ -131,20 +123,20 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion2, get2);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testAllWithNewAxis() {
+ public void testAllWithNewAxis(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray get = arr.get(newAxis(), all(), point(1));
INDArray assertion = Nd4j.create(new double[][] {{4, 5, 6}, {10, 11, 12}, {16, 17, 18}, {22, 23, 24}})
- .reshape(1, 4, 3);
+ .reshape(1, 4, 3);
assertEquals(assertion, get);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testIndexingWithMmul() {
+ public void testIndexingWithMmul(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
// System.out.println(b);
@@ -154,9 +146,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, c);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testPointPointInterval() {
+ public void testPointPointInterval(Nd4jBackend backend) {
INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3);
INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3));
INDArray assertion = Nd4j.create(new double[][] {{5, 6}, {8, 9}});
@@ -164,9 +156,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, get);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testIntervalLowerBound() {
+ public void testIntervalLowerBound(Nd4jBackend backend) {
INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
INDArray assertion = Nd4j.create(new double[][] {{7, 9}, {13, 15}});
@@ -176,9 +168,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testGetPointRowVector() {
+ public void testGetPointRowVector(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1);
INDArray arr2 = arr.get(point(0), interval(0, 100));
@@ -187,9 +179,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(Nd4j.linspace(1, 100, 100, DataType.DOUBLE), arr2);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testSpecifiedIndexVector() {
+ public void testSpecifiedIndexVector(Nd4jBackend backend) {
INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2);
INDArray get = rootMatrix.get(all(), new SpecifiedIndex(0, 2));
@@ -205,9 +197,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testPutRowIndexing() {
+ public void testPutRowIndexing(Nd4jBackend backend) {
INDArray arr = Nd4j.ones(1, 10);
INDArray row = Nd4j.create(1, 10);
@@ -216,9 +208,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(arr, row);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testVectorIndexing2() {
+ public void testVectorIndexing2(Nd4jBackend backend) {
INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true));
INDArray assertion = Nd4j.create(new double[] {2, 4});
assertEquals(assertion, wholeVector);
@@ -232,9 +224,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testOffsetsC() {
+ public void testOffsetsC(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
assertEquals(3, NDArrayIndex.offset(arr, 1, 1));
assertEquals(3, NDArrayIndex.offset(arr, point(1), point(1)));
@@ -249,9 +241,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testIndexFor() {
+ public void testIndexFor(Nd4jBackend backend) {
long[] shape = {1, 2};
INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape);
for (int i = 0; i < indexes.length; i++) {
@@ -259,9 +251,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testGetScalar() {
+ public void testGetScalar(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
INDArray d = arr.get(point(1));
assertTrue(d.isScalar());
@@ -269,26 +261,26 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testVectorIndexing() {
+ public void testVectorIndexing(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1);
INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
INDArray viewTest = arr.get(point(0), interval(1, 5));
assertEquals(assertion, viewTest);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testNegativeIndices() {
+ public void testNegativeIndices(Nd4jBackend backend) {
INDArray test = Nd4j.create(10, 10, 10);
test.putScalar(new int[] {0, 0, -1}, 1.0);
assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber());
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testGetIndices2d() {
+ public void testGetIndices2d(Nd4jBackend backend) {
INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2);
INDArray firstRow = twoByTwo.getRow(0);
INDArray secondRow = twoByTwo.getRow(1);
@@ -305,9 +297,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(Nd4j.create(new double[] {4}, new int[]{1,1}), individualElement);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testGetRow() {
+ public void testGetRow(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5);
int[] toGet = {0, 1};
@@ -323,9 +315,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testGetRowEdgeCase() {
+ public void testGetRowEdgeCase(Nd4jBackend backend) {
INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
INDArray get = rowVec.getRow(0); //Returning shape [1,1]
@@ -333,9 +325,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(rowVec, get);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testGetColumnEdgeCase() {
+ public void testGetColumnEdgeCase(Nd4jBackend backend) {
INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose();
INDArray get = colVec.getColumn(0); //Returning shape [1,1]
@@ -343,9 +335,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(colVec, get);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testConcatColumns() {
+ public void testConcatColumns(Nd4jBackend backend) {
INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE);
INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE);
INDArray concat = Nd4j.concat(1, input1, input2);
@@ -353,18 +345,18 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, concat);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testGetIndicesVector() {
+ public void testGetIndicesVector(Nd4jBackend backend) {
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
INDArray test = Nd4j.create(new double[] {2, 3});
INDArray result = line.get(point(0), interval(1, 3));
assertEquals(test, result);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testArangeMul() {
+ public void testArangeMul(Nd4jBackend backend) {
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
INDArrayIndex index = interval(0, 2);
INDArray get = arange.get(index, index);
@@ -374,7 +366,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, mul);
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIndexingThorough(){
long[] fullShape = {3,4,5,6,7};
@@ -575,7 +567,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
return d;
}
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void debugging(){
long[] inShape = {3,4};
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java
index 694016812..9e5491098 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxt.java
@@ -46,12 +46,13 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j
-
public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
+ @TempDir Path testDir;
+
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void compareAfterWrite(Nd4jBackend backend) throws Exception {
int [] ranksToCheck = new int[] {0,1,2,3,4};
for (int i = 0; i < ranksToCheck.length; i++) {
// log.info("Checking read write arrays with rank " + ranksToCheck[i]);
@@ -82,7 +83,7 @@ public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testNd4jReadWriteText(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void testNd4jReadWriteText(Nd4jBackend backend) throws Exception {
File dir = testDir.toFile();
int count = 0;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxtC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxtC.java
index f8dcfda03..861412773 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxtC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/ndarray/TestNdArrReadWriteTxtC.java
@@ -38,11 +38,11 @@ import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays;
@Slf4j
public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends {
-
+ @TempDir Path testDir;
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void compareAfterWrite(Nd4jBackend backend) throws Exception {
int[] ranksToCheck = new int[]{0, 1, 2, 3, 4};
for (int i = 0; i < ranksToCheck.length; i++) {
log.info("Checking read write arrays with rank " + ranksToCheck[i]);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java
index 22f17f103..ccbd72bc7 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java
@@ -22,6 +22,7 @@ package org.nd4j.linalg.broadcast;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
+import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
@@ -135,7 +136,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
assertEquals(e, z);
}
- @Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_1(Nd4jBackend backend) {
@@ -146,7 +146,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
});
}
- @Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_2(Nd4jBackend backend) {
@@ -158,7 +157,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
}
- @Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_3(Nd4jBackend backend) {
@@ -170,16 +168,15 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
}
- @Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ @Disabled
public void basicBroadcastFailureTest_4(Nd4jBackend backend) {
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
val z = x.addi(y);
}
- @Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_5(Nd4jBackend backend) {
@@ -191,7 +188,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
}
- @Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_6(Nd4jBackend backend) {
@@ -249,9 +245,9 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
assertEquals(y, z);
}
- @Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ @Disabled
public void emptyBroadcastTest_2(Nd4jBackend backend) {
val x = Nd4j.create(DataType.FLOAT, 1, 2);
val y = Nd4j.create(DataType.FLOAT, 0, 2);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionMagicTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionMagicTests.java
index 1f1ccd430..75164f89b 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionMagicTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/compression/CompressionMagicTests.java
@@ -37,7 +37,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class CompressionMagicTests extends BaseNd4jTestWithBackends {
@BeforeEach
- public void setUp(Nd4jBackend backend) {
+ public void setUp() {
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java
index e39678f4f..4809aa379 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java
@@ -48,6 +48,7 @@ import java.util.Set;
public class DeconvTests extends BaseNd4jTestWithBackends {
+ @TempDir Path testDir;
@Override
public char ordering() {
@@ -56,7 +57,7 @@ public class DeconvTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void compareKeras(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void compareKeras(Nd4jBackend backend) throws Exception {
File newFolder = testDir.toFile();
new ClassPathResource("keras/deconv/").copyDirectory(newFolder);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java
index 59ef09082..92d274d24 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java
@@ -99,7 +99,8 @@ public class SpecialTests extends BaseNd4jTestWithBackends {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScalarShuffle1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
List listData = new ArrayList<>();
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
index a6a0ab8a6..4ed12dc0e 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
@@ -195,7 +195,8 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
assertEquals(exp, arrayX);
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInplaceOp1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
val arrayX = Nd4j.create(10, 10);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java
index 056bc7ba3..55ae9457b 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java
@@ -41,10 +41,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
+ @TempDir Path testDir;
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testBalance(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void testBalance(Nd4jBackend backend) throws Exception {
DataSetIterator iterator = new IrisDataSetIterator(10, 150);
File minibatches = new File(testDir.toFile(),"mini-batch-dir");
@@ -62,7 +63,7 @@ public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testMiniBatchBalanced(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void testMiniBatchBalanced(Nd4jBackend backend) throws Exception {
int miniBatchSize = 100;
DataSetIterator iterator = new IrisDataSetIterator(miniBatchSize, 150);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java
index a0e14ac16..f16dccd08 100755
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/DataSetTest.java
@@ -51,8 +51,10 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
@Slf4j
public class DataSetTest extends BaseNd4jTestWithBackends {
-
- @ParameterizedTest
+
+ @TempDir Path testDir;
+
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testViewIterator(Nd4jBackend backend) {
DataSetIterator iter = new ViewIterator(new IrisDataSetIterator(150, 150).next(), 10);
@@ -106,9 +108,9 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testSplitTestAndTrain (Nd4jBackend backend) {
+ public void testSplitTestAndTrain(Nd4jBackend backend) {
INDArray labels = FeatureUtil.toOutcomeMatrix(new int[] {0, 0, 0, 0, 0, 0, 0, 0}, 1);
DataSet data = new DataSet(Nd4j.rand(8, 1), labels);
@@ -116,7 +118,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
assertEquals(train.getTrain().getLabels().length(), 6);
SplitTestAndTrain train2 = data.splitTestAndTrain(6, new Random(1));
- assertEquals(train.getTrain().getFeatures(), train2.getTrain().getFeatures(),getFailureMessage());
+ assertEquals(train.getTrain().getFeatures(), train2.getTrain().getFeatures(),getFailureMessage(backend));
DataSet x0 = new IrisDataSetIterator(150, 150).next();
SplitTestAndTrain testAndTrain = x0.splitTestAndTrain(10);
@@ -144,7 +146,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
SplitTestAndTrain testAndTrainRng = x2.splitTestAndTrain(10, rngHere);
assertArrayEquals(testAndTrainRng.getTrain().getFeatures().shape(),
- testAndTrain.getTrain().getFeatures().shape());
+ testAndTrain.getTrain().getFeatures().shape());
assertEquals(testAndTrainRng.getTrain().getFeatures(), testAndTrain.getTrain().getFeatures());
assertEquals(testAndTrainRng.getTrain().getLabels(), testAndTrain.getTrain().getLabels());
@@ -154,13 +156,13 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLabelCounts(Nd4jBackend backend) {
DataSet x0 = new IrisDataSetIterator(150, 150).next();
- assertEquals(0, x0.get(0).outcome(),getFailureMessage());
- assertEquals( 0, x0.get(1).outcome(),getFailureMessage());
- assertEquals(2, x0.get(149).outcome(),getFailureMessage());
+ assertEquals(0, x0.get(0).outcome(),getFailureMessage(backend));
+ assertEquals( 0, x0.get(1).outcome(),getFailureMessage(backend));
+ assertEquals(2, x0.get(149).outcome(),getFailureMessage(backend));
Map counts = x0.labelCounts();
- assertEquals(50, counts.get(0), 1e-1,getFailureMessage());
- assertEquals(50, counts.get(1), 1e-1,getFailureMessage());
- assertEquals(50, counts.get(2), 1e-1,getFailureMessage());
+ assertEquals(50, counts.get(0), 1e-1,getFailureMessage(backend));
+ assertEquals(50, counts.get(1), 1e-1,getFailureMessage(backend));
+ assertEquals(50, counts.get(2), 1e-1,getFailureMessage(backend));
}
@@ -694,14 +696,14 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
INDArray expLabels3d = Nd4j.create(3, 3, 4);
expLabels3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)},
- l3d1);
+ l3d1);
expLabels3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
- NDArrayIndex.interval(0, 3)}, l3d2);
+ NDArrayIndex.interval(0, 3)}, l3d2);
INDArray expLM3d = Nd4j.create(3, 3, 4);
expLM3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)},
- lm3d1);
+ lm3d1);
expLM3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
- NDArrayIndex.interval(0, 3)}, lm3d2);
+ NDArrayIndex.interval(0, 3)}, lm3d2);
DataSet merged3d = DataSet.merge(Arrays.asList(ds3d1, ds3d2));
@@ -752,52 +754,52 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testShuffleNd(Nd4jBackend backend) {
- int numDims = 7;
- int nLabels = 3;
- Random r = new Random();
+ int numDims = 7;
+ int nLabels = 3;
+ Random r = new Random();
- int[] shape = new int[numDims];
- int entries = 1;
- for (int i = 0; i < numDims; i++) {
- //randomly generating shapes bigger than 1
- shape[i] = r.nextInt(4) + 2;
- entries *= shape[i];
- }
- int labels = shape[0] * nLabels;
+ int[] shape = new int[numDims];
+ int entries = 1;
+ for (int i = 0; i < numDims; i++) {
+ //randomly generating shapes bigger than 1
+ shape[i] = r.nextInt(4) + 2;
+ entries *= shape[i];
+ }
+ int labels = shape[0] * nLabels;
- INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape);
- INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels);
+ INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape);
+ INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels);
- DataSet ds = new DataSet(ds_data, ds_labels);
- ds.shuffle();
+ DataSet ds = new DataSet(ds_data, ds_labels);
+ ds.shuffle();
- //Checking Nd dataset which is the data
- for (int dim = 1; dim < numDims; dim++) {
- //get tensor along dimension - the order in every dimension but zero should be preserved
- for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) {
- //the difference between consecutive elements should be equal to the stride
- for (int i = 0, j = 1; j < shape[dim]; i++, j++) {
- int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i);
- int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j);
- int f_element_diff = f_next_element - f_element;
- assertEquals(f_element_diff, ds_data.stride(dim));
- }
- }
- }
-
- //Checking 2d, features
- int dim = 1;
+ //Checking Nd dataset which is the data
+ for (int dim = 1; dim < numDims; dim++) {
//get tensor along dimension - the order in every dimension but zero should be preserved
- for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) {
+ for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) {
//the difference between consecutive elements should be equal to the stride
- for (int i = 0, j = 1; j < nLabels; i++, j++) {
- int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i);
- int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j);
- int l_element_diff = l_next_element - l_element;
- assertEquals(l_element_diff, ds_labels.stride(dim));
+ for (int i = 0, j = 1; j < shape[dim]; i++, j++) {
+ int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i);
+ int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j);
+ int f_element_diff = f_next_element - f_element;
+ assertEquals(f_element_diff, ds_data.stride(dim));
}
}
+ }
+
+ //Checking 2d, features
+ int dim = 1;
+ //get tensor along dimension - the order in every dimension but zero should be preserved
+ for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) {
+ //the difference between consecutive elements should be equal to the stride
+ for (int i = 0, j = 1; j < nLabels; i++, j++) {
+ int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i);
+ int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j);
+ int l_element_diff = l_next_element - l_element;
+ assertEquals(l_element_diff, ds_labels.stride(dim));
+ }
+ }
}
@ParameterizedTest
@@ -936,9 +938,9 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
//Checking if the features and labels are equal
assertEquals(iDataSet.getFeatures(),
- dsList.get(i).getFeatures().get(all(), all(), interval(0, minTSLength + i)));
+ dsList.get(i).getFeatures().get(all(), all(), interval(0, minTSLength + i)));
assertEquals(iDataSet.getLabels(),
- dsList.get(i).getLabels().get(all(), all(), interval(0, minTSLength + i)));
+ dsList.get(i).getLabels().get(all(), all(), interval(0, minTSLength + i)));
}
}
@@ -964,8 +966,8 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
for (boolean lMask : b) {
DataSet ds = new DataSet((features ? f : null),
- (labels ? (labelsSameAsFeatures ? f : l) : null), (fMask ? fm : null),
- (lMask ? lm : null));
+ (labels ? (labelsSameAsFeatures ? f : l) : null), (fMask ? fm : null),
+ (lMask ? lm : null));
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
@@ -1009,7 +1011,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
boolean lMask = true;
DataSet ds = new DataSet((features ? f : null), (labels ? (labelsSameAsFeatures ? f : l) : null),
- (fMask ? fm : null), (lMask ? lm : null));
+ (fMask ? fm : null), (lMask ? lm : null));
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
@@ -1098,7 +1100,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
+ public void testDataSetMetaDataSerialization(Nd4jBackend backend) throws IOException {
for(boolean withMeta : new boolean[]{false, true}) {
// create simple data set with meta data object
@@ -1129,7 +1131,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testMultiDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend nd4jBackend) throws IOException {
+ public void testMultiDataSetMetaDataSerialization(Nd4jBackend nd4jBackend) throws IOException {
for(boolean withMeta : new boolean[]{false, true}) {
// create simple data set with meta data object
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java
index 152466d7d..beef4223d 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/KFoldIteratorTest.java
@@ -106,7 +106,8 @@ public class KFoldIteratorTest extends BaseNd4jTestWithBackends {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void checkCornerCaseException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
DataSet allData = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1),
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIteratorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIteratorTest.java
index 4b4196e98..5a4873203 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIteratorTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIteratorTest.java
@@ -21,27 +21,25 @@
package org.nd4j.linalg.dataset;
-import org.junit.jupiter.api.Test;
-
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
-
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4jBackend;
import java.nio.file.Path;
import static org.junit.jupiter.api.Assertions.assertEquals;
-
public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTestWithBackends {
+ @TempDir Path testDir;
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testMiniBatches(@TempDir Path testDir) throws Exception {
+ public void testMiniBatches(Nd4jBackend backend) throws Exception {
DataSet load = new IrisDataSetIterator(150, 150).next();
final MiniBatchFileDataSetIterator iter = new MiniBatchFileDataSetIterator(load, 10, false, testDir.toFile());
while (iter.hasNext())
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java
index d720d815c..5d5765ac8 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java
@@ -39,8 +39,7 @@ public class CompositeDataSetPreProcessorTest extends BaseNd4jTestWithBackends {
return 'c';
}
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_preConditionsIsNull_expect_NullPointerException(Nd4jBackend backend) {
assertThrows(NullPointerException.class,() -> {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java
index 923a8f7ee..28377da43 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java
@@ -41,8 +41,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
return 'c';
}
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_originalHeightIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@@ -51,8 +50,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_originalWidthIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@@ -61,8 +59,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_yStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@@ -71,8 +68,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_xStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@@ -81,8 +77,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@@ -91,8 +86,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@@ -101,8 +95,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@@ -111,8 +104,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
// Assemble
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java
index 4cb743883..a3155734c 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java
@@ -39,7 +39,8 @@ public class PermuteDataSetPreProcessorTest extends BaseNd4jTestWithBackends {
return 'c';
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
assertThrows(NullPointerException.class,() -> {
// Assemble
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java
index 071bcfb85..b56220c7e 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java
@@ -20,7 +20,6 @@
package org.nd4j.linalg.dataset.api.preprocessor;
-import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
@@ -39,7 +38,8 @@ public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTestWithBacke
return 'c';
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
assertThrows(NullPointerException.class,() -> {
// Assemble
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java
index 518bd19ca..cad8f7eda 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/Nd4jTest.java
@@ -139,7 +139,7 @@ public class Nd4jTest extends BaseNd4jTestWithBackends {
INDArray actualResult = data.mean(0);
INDArray expectedResult = Nd4j.create(new double[] {3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6.,
6., 3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6.}, new int[] {2, 4, 4});
- assertEquals(expectedResult, actualResult,getFailureMessage());
+ assertEquals(expectedResult, actualResult,getFailureMessage(backend));
}
@@ -154,7 +154,7 @@ public class Nd4jTest extends BaseNd4jTestWithBackends {
INDArray actualResult = data.var(false, 0);
INDArray expectedResult = Nd4j.create(new double[] {1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4.,
4., 1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4.}, new long[] {2, 4, 4});
- assertEquals(expectedResult, actualResult,getFailureMessage());
+ assertEquals(expectedResult, actualResult,getFailureMessage(backend));
}
@ParameterizedTest
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java
index 6ce604bb5..239e43839 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/memory/CloseableTests.java
@@ -83,8 +83,7 @@ public class CloseableTests extends BaseNd4jTestWithBackends {
}
}
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAccessException_1(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
@@ -96,8 +95,7 @@ public class CloseableTests extends BaseNd4jTestWithBackends {
}
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAccessException_2(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java
index d4f3058ff..5b5c46915 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java
@@ -384,7 +384,9 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
assertEquals(exp, arrayZ);
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+
public void testTypesValidation_1(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.LONG);
@@ -397,7 +399,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTypesValidation_2(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
@@ -412,7 +415,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTypesValidation_3(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
@@ -422,6 +426,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
}
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTypesValidation_4(Nd4jBackend backend) {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.DOUBLE);
@@ -485,7 +491,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testBoolFloatCast2(){
+ public void testBoolFloatCast2(Nd4jBackend backend){
val first = Nd4j.zeros(DataType.FLOAT, 3, 5000);
INDArray asBool = first.castTo(DataType.BOOL);
INDArray not = Transforms.not(asBool); //
@@ -516,7 +522,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testAssignScalarSimple(){
+ public void testAssignScalarSimple(Nd4jBackend backend){
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
INDArray arr = Nd4j.scalar(dt, 10.0);
arr.assign(2.0);
@@ -526,7 +532,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testSimple(){
+ public void testSimple(Nd4jBackend backend){
Nd4j.create(1);
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, DataType.LONG}) {
// System.out.println("----- " + dt + " -----");
@@ -551,7 +557,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testWorkspaceBool(){
+ public void testWorkspaceBool(Nd4jBackend backend){
val conf = WorkspaceConfiguration.builder().minSize(10 * 1024 * 1024)
.overallocationLimit(1.0).policyAllocation(AllocationPolicy.OVERALLOCATE)
.policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL)
@@ -559,7 +565,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(conf, "WS");
- for( int i=0; i<10; i++ ) {
+ for( int i = 0; i < 10; i++ ) {
try (val workspace = (Nd4jWorkspace)ws.notifyScopeEntered() ) {
val bool = Nd4j.create(DataType.BOOL, 1, 10);
val dbl = Nd4j.create(DataType.DOUBLE, 1, 10);
@@ -574,8 +580,9 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
}
}
- @Test
- @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ @Disabled
public void testArrayCreationFromPointer(Nd4jBackend backend) {
val source = Nd4j.create(new double[]{1, 2, 3, 4, 5});
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java
index 0717bd0d3..c09403f83 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/NativeBlasTests.java
@@ -40,13 +40,13 @@ public class NativeBlasTests extends BaseNd4jTestWithBackends {
@BeforeEach
- public void setUp(Nd4jBackend backend) {
+ public void setUp() {
Nd4j.getExecutioner().enableDebugMode(true);
Nd4j.getExecutioner().enableVerboseMode(true);
}
@AfterEach
- public void setDown(Nd4jBackend backend) {
+ public void setDown() {
Nd4j.getExecutioner().enableDebugMode(false);
Nd4j.getExecutioner().enableVerboseMode(false);
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java
index 3cb758a1b..d4fb22de8 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java
@@ -77,18 +77,18 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
double sim = Transforms.cosineSim(vec1, vec2);
- assertEquals( 1, sim, 1e-1,getFailureMessage());
+ assertEquals( 1, sim, 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testCosineDistance(){
+ public void testCosineDistance(Nd4jBackend backend){
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3});
INDArray vec2 = Nd4j.create(new float[] {3, 5, 7});
// 1-17*sqrt(2/581)
double distance = Transforms.cosineDistance(vec1, vec2);
- assertEquals(0.0025851, distance, 1e-7,getFailureMessage());
+ assertEquals(0.0025851, distance, 1e-7,getFailureMessage(backend));
}
@ParameterizedTest
@@ -97,7 +97,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray arr = Nd4j.create(new double[] {55, 55});
INDArray arr2 = Nd4j.create(new double[] {60, 60});
double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).z().getDouble(0);
- assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage());
+ assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@@ -137,7 +137,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi();
INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6);
Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1));
- assertEquals(scalarMax, postMax,getFailureMessage());
+ assertEquals(scalarMax, postMax,getFailureMessage(backend));
}
@ParameterizedTest
@@ -147,14 +147,14 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1));
for (int i = 0; i < linspace.length(); i++) {
double val = linspace.getDouble(i);
- assertTrue( val >= 0 && val <= 1,getFailureMessage());
+ assertTrue( val >= 0 && val <= 1,getFailureMessage(backend));
}
INDArray linspace2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
Nd4j.getExecutioner().exec(new SetRange(linspace2, 2, 4));
for (int i = 0; i < linspace2.length(); i++) {
double val = linspace2.getDouble(i);
- assertTrue( val >= 2 && val <= 4,getFailureMessage());
+ assertTrue( val >= 2 && val <= 4,getFailureMessage(backend));
}
}
@@ -163,7 +163,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
public void testNormMax(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).z().getDouble(0);
- assertEquals(4, normMax, 1e-1,getFailureMessage());
+ assertEquals(4, normMax, 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@@ -187,7 +187,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
public void testNorm2(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).z().getDouble(0);
- assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage());
+ assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@@ -198,7 +198,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray xDup = x.dup();
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
- assertEquals(solution, x,getFailureMessage());
+ assertEquals(solution, x,getFailureMessage(backend));
}
@ParameterizedTest
@@ -221,13 +221,13 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray xDup = x.dup();
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
- assertEquals(solution, x,getFailureMessage());
+ assertEquals(solution, x,getFailureMessage(backend));
Sum acc = new Sum(x.dup());
opExecutioner.exec(acc);
- assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
Prod prod = new Prod(x.dup());
opExecutioner.exec(prod);
- assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
}
@@ -275,7 +275,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Variance variance = new Variance(x.dup(), true);
opExecutioner.exec(variance);
- assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
}
@@ -284,14 +284,14 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIamax(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
- assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage());
+ assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage(backend));
}
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIamax2(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
- assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage());
+ assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage(backend));
val op = new ArgAmax(linspace);
int iamax = Nd4j.getExecutioner().exec(op)[0].getInt(0);
@@ -307,11 +307,11 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Mean mean = new Mean(x);
opExecutioner.exec(mean);
- assertEquals( 3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals( 3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
Variance variance = new Variance(x.dup(), true);
opExecutioner.exec(variance);
- assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@@ -321,7 +321,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
val softMax = new SoftMax(arr);
opExecutioner.exec((CustomOp) softMax);
- assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
}
@@ -332,7 +332,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Pow pow = new Pow(oneThroughSix, 2);
Nd4j.getExecutioner().exec(pow);
INDArray answer = Nd4j.create(new double[] {1, 4, 9, 16, 25, 36});
- assertEquals(answer, pow.z(),getFailureMessage());
+ assertEquals(answer, pow.z(),getFailureMessage(backend));
}
@@ -384,7 +384,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Log log = new Log(slice);
opExecutioner.exec(log);
INDArray assertion = Nd4j.create(new double[] {0., 1.09861229, 1.60943791});
- assertEquals(assertion, slice,getFailureMessage());
+ assertEquals(assertion, slice,getFailureMessage(backend));
}
@ParameterizedTest
@@ -572,7 +572,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
expected[i] = (float) Math.exp(slice.getDouble(i));
Exp exp = new Exp(slice);
opExecutioner.exec(exp);
- assertEquals( Nd4j.create(expected), slice,getFailureMessage());
+ assertEquals( Nd4j.create(expected), slice,getFailureMessage(backend));
}
@ParameterizedTest
@@ -582,7 +582,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
val softMax = new SoftMax(arr);
opExecutioner.exec((CustomOp) softMax);
- assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java
index 2f6e4d874..151db8db2 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java
@@ -84,7 +84,7 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
@AfterEach
- public void after(Nd4jBackend backend) {
+ public void after() {
Nd4j.setDataType(this.initialType);
}
@@ -140,17 +140,17 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
double sim = Transforms.cosineSim(vec1, vec2);
- assertEquals(1, sim, 1e-1,getFailureMessage());
+ assertEquals(1, sim, 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testCosineDistance(){
+ public void testCosineDistance(Nd4jBackend backend){
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3});
INDArray vec2 = Nd4j.create(new float[] {3, 5, 7});
// 1-17*sqrt(2/581)
double distance = Transforms.cosineDistance(vec1, vec2);
- assertEquals( 0.0025851, distance, 1e-7,getFailureMessage());
+ assertEquals( 0.0025851, distance, 1e-7,getFailureMessage(backend));
}
@ParameterizedTest
@@ -179,7 +179,7 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
INDArray arr2 = Nd4j.create(new double[] {60, 60});
double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).getFinalResult()
.doubleValue();
- assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage());
+ assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@@ -188,7 +188,7 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi();
INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6);
Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1));
- assertEquals(postMax, scalarMax,getFailureMessage());
+ assertEquals(postMax, scalarMax,getFailureMessage(backend));
}
@ParameterizedTest
@@ -198,14 +198,14 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1));
for (int i = 0; i < linspace.length(); i++) {
double val = linspace.getDouble(i);
- assertTrue( val >= 0 && val <= 1,getFailureMessage());
+ assertTrue( val >= 0 && val <= 1,getFailureMessage(backend));
}
INDArray linspace2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
Nd4j.getExecutioner().exec(new SetRange(linspace2, 2, 4));
for (int i = 0; i < linspace2.length(); i++) {
double val = linspace2.getDouble(i);
- assertTrue(val >= 2 && val <= 4,getFailureMessage());
+ assertTrue(val >= 2 && val <= 4,getFailureMessage(backend));
}
}
@@ -215,7 +215,7 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
public void testNormMax(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).getFinalResult().doubleValue();
- assertEquals(4, normMax, 1e-1,getFailureMessage());
+ assertEquals(4, normMax, 1e-1,getFailureMessage(backend));
}
@@ -224,7 +224,7 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
public void testNorm2(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).getFinalResult().doubleValue();
- assertEquals( 5.4772255750516612, norm2, 1e-1,getFailureMessage());
+ assertEquals( 5.4772255750516612, norm2, 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@@ -235,7 +235,7 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
INDArray xDup = x.dup();
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
- assertEquals(solution, x,getFailureMessage());
+ assertEquals(solution, x,getFailureMessage(backend));
}
@ParameterizedTest
@@ -258,13 +258,13 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
INDArray xDup = x.dup();
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{ x}));
- assertEquals(solution, x,getFailureMessage());
+ assertEquals(solution, x,getFailureMessage(backend));
Sum acc = new Sum(x.dup());
opExecutioner.exec(acc);
- assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
Prod prod = new Prod(x.dup());
opExecutioner.exec(prod);
- assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
}
@@ -316,7 +316,7 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
Variance variance = new Variance(x.dup(), true);
opExecutioner.exec(variance);
- assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
}
@@ -328,11 +328,11 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
Mean mean = new Mean(x);
opExecutioner.exec(mean);
- assertEquals(3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals(3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
Variance variance = new Variance(x.dup(), true);
opExecutioner.exec(variance);
- assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@@ -342,7 +342,7 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
val softMax = new SoftMax(arr);
opExecutioner.exec((CustomOp) softMax);
- assertEquals( 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals( 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@@ -373,7 +373,7 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
Pow pow = new Pow(oneThroughSix, 2);
Nd4j.getExecutioner().exec(pow);
INDArray answer = Nd4j.create(new double[] {1, 4, 9, 16, 25, 36});
- assertEquals(answer, pow.z(),getFailureMessage());
+ assertEquals(answer, pow.z(),getFailureMessage(backend));
}
@@ -427,7 +427,7 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
Log exp = new Log(slice);
opExecutioner.exec(exp);
INDArray assertion = Nd4j.create(new double[] {0.0, 0.6931471824645996, 1.0986123085021973});
- assertEquals(assertion, slice,getFailureMessage());
+ assertEquals(assertion, slice,getFailureMessage(backend));
}
@ParameterizedTest
@@ -441,7 +441,7 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
expected[i] = (float) Math.exp(slice.getDouble(i));
Exp exp = new Exp(slice);
opExecutioner.exec(exp);
- assertEquals(Nd4j.create(expected), slice,getFailureMessage());
+ assertEquals(Nd4j.create(expected), slice,getFailureMessage(backend));
}
@ParameterizedTest
@@ -451,7 +451,7 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
val softMax = new SoftMax(arr);
opExecutioner.exec(softMax);
- assertEquals( 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage());
+ assertEquals( 1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
val softmax = new SoftMax(linspace.dup());
@@ -467,7 +467,7 @@ public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
val max = new SoftMax(linspace);
Nd4j.getExecutioner().exec(max);
linspace.assign(max.outputArguments().get(0));
- assertEquals(linspace.getRow(0).sumNumber().doubleValue(), 1.0, 1e-1,getFailureMessage());
+ assertEquals(linspace.getRow(0).sumNumber().doubleValue(), 1.0, 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/InfNanTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/InfNanTests.java
index b9c0f3cb8..e436cb76d 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/InfNanTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/InfNanTests.java
@@ -50,7 +50,6 @@ public class InfNanTests extends BaseNd4jTestWithBackends {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED);
}
- @Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInf1(Nd4jBackend backend) {
@@ -67,7 +66,6 @@ public class InfNanTests extends BaseNd4jTestWithBackends {
}
- @Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInf2(Nd4jBackend backend) {
@@ -103,7 +101,6 @@ public class InfNanTests extends BaseNd4jTestWithBackends {
OpExecutionerUtil.checkForAny(x);
}
- @Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNaN1(Nd4jBackend backend) {
@@ -120,7 +117,6 @@ public class InfNanTests extends BaseNd4jTestWithBackends {
}
- @Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNaN2(Nd4jBackend backend) {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java
index f83334582..4aff0470d 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java
@@ -306,7 +306,8 @@ public class OperationProfilerTests extends BaseNd4jTestWithBackends {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNaNPanic1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC);
@@ -318,7 +319,8 @@ public class OperationProfilerTests extends BaseNd4jTestWithBackends {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNaNPanic2(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.INF_PANIC);
@@ -330,7 +332,8 @@ public class OperationProfilerTests extends BaseNd4jTestWithBackends {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNaNPanic3(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC);
@@ -343,7 +346,8 @@ public class OperationProfilerTests extends BaseNd4jTestWithBackends {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScopePanic1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC);
@@ -362,7 +366,8 @@ public class OperationProfilerTests extends BaseNd4jTestWithBackends {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScopePanic2(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java
index 614007a9e..831046c1d 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/PerformanceTrackerTests.java
@@ -45,13 +45,13 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class PerformanceTrackerTests extends BaseNd4jTestWithBackends {
@BeforeEach
- public void setUp(Nd4jBackend backend) {
+ public void setUp() {
PerformanceTracker.getInstance().clear();
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.BANDWIDTH);
}
@AfterEach
- public void tearDown(Nd4jBackend backend) {
+ public void tearDown() {
PerformanceTracker.getInstance().clear();
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC);
}
@@ -109,7 +109,8 @@ public class PerformanceTrackerTests extends BaseNd4jTestWithBackends {
assertEquals(500, res);
}
- @Test
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testTrackerCpu_1(Nd4jBackend backend) {
if (!Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("native"))
@@ -127,7 +128,8 @@ public class PerformanceTrackerTests extends BaseNd4jTestWithBackends {
assertTrue(bw > 0);
}
- @Test
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled("useless these days")
public void testTrackerGpu_1(Nd4jBackend backend) {
if (!Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("cuda"))
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/StackAggregatorTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/StackAggregatorTests.java
index 81ede4120..0f6630153 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/StackAggregatorTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/StackAggregatorTests.java
@@ -50,14 +50,14 @@ public class StackAggregatorTests extends BaseNd4jTestWithBackends {
}
@BeforeEach
- public void setUp(Nd4jBackend backend) {
+ public void setUp() {
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().stackTrace(true).build());
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ALL);
OpProfiler.getInstance().reset();
}
@AfterEach
- public void tearDown(Nd4jBackend backend) {
+ public void tearDown() {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED);
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java
index 5861372b6..c7f05bfc3 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java
@@ -26,6 +26,7 @@ import static org.junit.jupiter.api.Assertions.fail;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
+import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
@@ -123,6 +124,7 @@ public class RngValidationTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ @Disabled
public void validateRngDistributions(Nd4jBackend backend){
List testCases = new ArrayList<>();
for(DataType type : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java
index b1c62bba4..c60dff2f2 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java
@@ -46,9 +46,11 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class NumpyFormatTests extends BaseNd4jTestWithBackends {
+ @TempDir Path testDir;
+
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testToNpyFormat(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void testToNpyFormat(Nd4jBackend backend) throws Exception {
val dir = testDir.toFile();
new ClassPathResource("numpy_arrays/").copyDirectory(dir);
@@ -98,7 +100,7 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testToNpyFormatScalars(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void testToNpyFormatScalars(Nd4jBackend backend) throws Exception {
// File dir = new File("C:\\DL4J\\Git\\dl4j-test-resources\\src\\main\\resources\\numpy_arrays\\scalar");
val dir = testDir.toFile();
@@ -153,7 +155,7 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testNpzReading(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void testNpzReading(Nd4jBackend backend) throws Exception {
val dir = testDir.toFile();
new ClassPathResource("numpy_arrays/npz/").copyDirectory(dir);
@@ -214,7 +216,8 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testNpy(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ @Disabled
+ public void testNpy(Nd4jBackend backend) throws Exception {
for(boolean empty : new boolean[]{false, true}) {
val dir = testDir.toFile();
if(!empty) {
@@ -264,8 +267,9 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends {
assertEquals(Nd4j.scalar(DataType.INT, 1), out);
}
- @Test()
- public void readNumpyCorruptHeader1(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ public void readNumpyCorruptHeader1(Nd4jBackend backend) throws Exception {
assertThrows(RuntimeException.class,() -> {
File f = testDir.toFile();
@@ -288,8 +292,9 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends {
}
- @Test()
- public void readNumpyCorruptHeader2(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ public void readNumpyCorruptHeader2(Nd4jBackend backend) throws Exception {
assertThrows(RuntimeException.class,() -> {
File f = testDir.toFile();
@@ -312,7 +317,8 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAbsentNumpyFile_1(Nd4jBackend backend) throws Exception {
assertThrows(IllegalArgumentException.class,() -> {
val f = new File("pew-pew-zomg.some_extension_that_wont_exist");
@@ -321,7 +327,9 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends {
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ @Disabled
public void testAbsentNumpyFile_2(Nd4jBackend backend) throws Exception {
assertThrows(IllegalArgumentException.class,() -> {
val f = new File("c:/develop/batch-x-1.npy");
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java
index 3244b5d2e..a554f0954 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java
@@ -184,8 +184,7 @@ public class EmptyTests extends BaseNd4jTestWithBackends {
assertEquals(1, array.rank());
}
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEmptyWithShape_3(Nd4jBackend backend) {
@@ -255,7 +254,6 @@ public class EmptyTests extends BaseNd4jTestWithBackends {
assertEquals(e, reduced);
}
- @Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEmptyReduction_4(Nd4jBackend backend) {
@@ -342,7 +340,6 @@ public class EmptyTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
-
public void testEmptyNoop(Nd4jBackend backend) {
val output = Nd4j.empty(DataType.LONG);
@@ -355,7 +352,6 @@ public class EmptyTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
-
public void testEmptyConstructor_1(Nd4jBackend backend) {
val x = Nd4j.create(new double[0]);
assertTrue(x.isEmpty());
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java
index b159acdb4..5c8495f82 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java
@@ -45,7 +45,7 @@ public class ShapeTestsC extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
@AfterEach
- public void after(Nd4jBackend backend) {
+ public void after() {
Nd4j.setDataType(this.initialType);
}
@@ -277,7 +277,7 @@ public class ShapeTestsC extends BaseNd4jTestWithBackends {
INDArray twoByThree = Nd4j.linspace(1, 600, 600, DataType.FLOAT).reshape(150, 4);
INDArray columnVar = twoByThree.sum(0);
INDArray assertion = Nd4j.create(new float[] {44850.0f, 45000.0f, 45150.0f, 45300.0f});
- assertEquals(assertion, columnVar,getFailureMessage());
+ assertEquals(assertion, columnVar,getFailureMessage(backend));
}
@@ -287,7 +287,7 @@ public class ShapeTestsC extends BaseNd4jTestWithBackends {
INDArray twoByThree = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray rowMean = twoByThree.mean(1);
INDArray assertion = Nd4j.create(new double[] {1.5, 3.5});
- assertEquals(assertion, rowMean,getFailureMessage());
+ assertEquals(assertion, rowMean,getFailureMessage(backend));
}
@@ -298,7 +298,7 @@ public class ShapeTestsC extends BaseNd4jTestWithBackends {
INDArray twoByThree = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray rowStd = twoByThree.std(1);
INDArray assertion = Nd4j.create(new double[] {0.7071067811865476f, 0.7071067811865476f});
- assertEquals(assertion, rowStd,getFailureMessage());
+ assertEquals(assertion, rowStd,getFailureMessage(backend));
}
@@ -311,7 +311,7 @@ public class ShapeTestsC extends BaseNd4jTestWithBackends {
INDArray twoByThree = Nd4j.linspace(1, 600, 600, DataType.DOUBLE).reshape(150, 4);
INDArray columnVar = twoByThree.sum(0);
INDArray assertion = Nd4j.create(new double[] {44850.0f, 45000.0f, 45150.0f, 45300.0f});
- assertEquals(assertion, columnVar,getFailureMessage());
+ assertEquals(assertion, columnVar,getFailureMessage(backend));
DataTypeUtil.setDTypeForContext(initialType);
}
@@ -333,14 +333,14 @@ public class ShapeTestsC extends BaseNd4jTestWithBackends {
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {1, 4});
INDArray cumSumAnswer = Nd4j.create(new double[] {1, 3, 6, 10}, new long[] {1, 4});
INDArray cumSumTest = n.cumsum(0);
- assertEquals( cumSumAnswer, cumSumTest,getFailureMessage());
+ assertEquals( cumSumAnswer, cumSumTest,getFailureMessage(backend));
INDArray n2 = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 3, 2);
INDArray axis0assertion = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0,
18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0}, n2.shape());
INDArray axis0Test = n2.cumsum(0);
- assertEquals(axis0assertion, axis0Test,getFailureMessage());
+ assertEquals(axis0assertion, axis0Test,getFailureMessage(backend));
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java
index bb6593b6f..fd0ddb762 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java
@@ -223,8 +223,7 @@ public class ConcatTestsC extends BaseNd4jTestWithBackends {
assertEquals(exp, concat2);
}
- @Test()
- @ParameterizedTest
+ @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConcatVector(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java
index df8704477..aa997a3c9 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java
@@ -55,7 +55,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
INDArray sub = nd.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2));
Nd4j.getExecutioner().exec(new ScalarAdd(sub, 2));
- assertEquals(Nd4j.create(new double[][] {{3, 4}, {6, 7}}), sub,getFailureMessage());
+ assertEquals(Nd4j.create(new double[][] {{3, 4}, {6, 7}}), sub,getFailureMessage(backend));
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java
index 235e04c72..58d665a79 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java
@@ -48,12 +48,12 @@ public class RavelIndexTest extends BaseNd4jTestWithBackends {
@BeforeEach
- public void setUp(Nd4jBackend backend) {
+ public void setUp() {
Nd4j.setDataType(DataType.FLOAT);
}
@AfterEach
- public void setDown(Nd4jBackend backend) {
+ public void setDown() {
Nd4j.setDataType(initialType);
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java
index 3811539d3..c516a55f7 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java
@@ -53,12 +53,12 @@ public class SortCooTests extends BaseNd4jTestWithBackends {
@BeforeEach
- public void setUp(Nd4jBackend backend) {
+ public void setUp() {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
}
@AfterEach
- public void setDown(Nd4jBackend backend) {
+ public void setDown() {
Nd4j.setDefaultDataTypes(initialType, Nd4j.defaultFloatingPointType());
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/DataSetUtilsTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/DataSetUtilsTest.java
index eaea0b5c1..9ffd78d90 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/DataSetUtilsTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/DataSetUtilsTest.java
@@ -42,6 +42,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j
public class DataSetUtilsTest extends BaseNd4jTestWithBackends {
+ @TempDir Path tmpFld;
@Override
public char ordering(){
@@ -53,10 +54,9 @@ public class DataSetUtilsTest extends BaseNd4jTestWithBackends {
//
private SIS sis;
//
- @Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testAll(@TempDir Path tmpFld,Nd4jBackend backend) {
+ public void testAll(Nd4jBackend backend) {
//
sis = new SIS();
//
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java
index 4866e5c3e..5c71fb1f7 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ShapeTestC.java
@@ -195,7 +195,8 @@ public class ShapeTestC extends BaseNd4jTestWithBackends {
assertArrayEquals(exp, norm);
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAxisNormalization_3(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
val axis = new int[] {1, -2, 2};
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java
index 9b17b0d6e..92483350a 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/util/ValidationUtilTests.java
@@ -51,9 +51,11 @@ import static org.junit.jupiter.api.Assertions.*;
public class ValidationUtilTests extends BaseNd4jTestWithBackends {
+ @TempDir Path testDir;
+
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testFileValidation(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void testFileValidation(Nd4jBackend backend) throws Exception {
File f = testDir.toFile();
//Test not existent file:
@@ -90,7 +92,7 @@ public class ValidationUtilTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testZipValidation(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void testZipValidation(Nd4jBackend backend) throws Exception {
File f = testDir.toFile();
//Test not existent file:
@@ -141,7 +143,7 @@ public class ValidationUtilTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testINDArrayTextValidation(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void testINDArrayTextValidation(Nd4jBackend backend) throws Exception {
File f = testDir.toFile();
//Test not existent file:
@@ -187,7 +189,7 @@ public class ValidationUtilTests extends BaseNd4jTestWithBackends {
INDArray arr = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3,4);
Nd4j.writeTxt(arr, fValid.getPath());
byte[] indarrayTxtBytes = FileUtils.readFileToByteArray(fValid);
- for( int i=0; i<30; i++ ){
+ for( int i = 0; i < 30; i++) {
indarrayTxtBytes[i] = (byte)('a' + i);
}
File fCorrupt = new File(f, "corrupt.txt");
@@ -210,11 +212,9 @@ public class ValidationUtilTests extends BaseNd4jTestWithBackends {
// System.out.println(vr4.toString());
}
-
- @Test
- @Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
- public void testNpyValidation(@TempDir Path testDir) throws Exception {
-
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ public void testNpyValidation(Nd4jBackend backend) throws Exception {
File f = testDir.toFile();
//Test not existent file:
@@ -283,9 +283,9 @@ public class ValidationUtilTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testNpzValidation(@TempDir Path testDIr,Nd4jBackend backend) throws Exception {
+ public void testNpzValidation(Nd4jBackend backend) throws Exception {
- File f = testDIr.toFile();
+ File f = testDir.toFile();
//Test not existent file:
File fNonExistent = new File("doesntExist.npz");
@@ -328,7 +328,7 @@ public class ValidationUtilTests extends BaseNd4jTestWithBackends {
//Test corrupted npz format:
File fValid = new ClassPathResource("numpy_arrays/npz/float32.npz").getFile();
byte[] numpyBytes = FileUtils.readFileToByteArray(fValid);
- for( int i=0; i<30; i++ ){
+ for( int i = 0; i < 30; i++) {
numpyBytes[i] = 0;
}
File fCorrupt = new File(f, "corrupt.npz");
@@ -353,7 +353,7 @@ public class ValidationUtilTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testNumpyTxtValidation(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void testNumpyTxtValidation(Nd4jBackend backend) throws Exception {
File f = testDir.toFile();
//Test not existent file:
@@ -422,7 +422,7 @@ public class ValidationUtilTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
- public void testValidateSameDiff(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
+ public void testValidateSameDiff(Nd4jBackend backend) throws Exception {
Nd4j.setDataType(DataType.FLOAT);
File f = testDir.toFile();
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java
index 61840e0d0..cb846d95b 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java
@@ -23,6 +23,7 @@ package org.nd4j.linalg.workspace;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
@@ -54,7 +55,7 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends {
private DataType initialType = Nd4j.dataType();
@AfterEach
- public void shutUp(Nd4jBackend backend) {
+ public void shutUp() {
Nd4j.getMemoryManager().setCurrentWorkspace(null);
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
Nd4j.setDataType(this.initialType);
@@ -62,6 +63,7 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ @Disabled
public void testVariableTimeSeries1(Nd4jBackend backend) {
WorkspaceConfiguration configuration = WorkspaceConfiguration
.builder()
@@ -170,6 +172,7 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
+ @Disabled
public void testVariableTimeSeries2(Nd4jBackend backend) {
WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(3.0)
.policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE)
@@ -247,7 +250,7 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends {
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "WS132143452343");
- for( int j=0; j<100; j++ ){
+ for( int j = 0; j < 100; j++) {
try(MemoryWorkspace ws = workspace.notifyScopeEntered()) {
@@ -409,7 +412,8 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends {
Files.delete(tmpFile);
}
- @Test()
+ @ParameterizedTest
+ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDeleteMappedFile_2() throws Exception {
assertThrows(IllegalArgumentException.class,() -> {
if (!Nd4j.getEnvironment().isCPU())
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java
index 0145589e3..e68caef4b 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java
@@ -112,7 +112,7 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
@AfterEach
- public void shutUp(Nd4jBackend backend) {
+ public void shutUp() {
Nd4j.getMemoryManager().setCurrentWorkspace(null);
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
Nd4j.setDataType(this.initialType);
diff --git a/nd4j/nd4j-backends/nd4j-tests/variables-added-old.txt b/nd4j/nd4j-backends/nd4j-tests/variables-added-old.txt
new file mode 100644
index 000000000..bed880a64
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/variables-added-old.txt
@@ -0,0 +1,18 @@
+in_0/read,in_0/read
+while/Enter,while/Enter
+while/Enter_1,while/Enter_1
+while/Merge,while/Merge
+while/Merge_1,while/Merge_1
+while/Less,while/Less
+while/LoopCond,while/LoopCond
+while/Switch,while/Switch
+while/Switch:1,while/Switch
+while/Switch_1,while/Switch_1
+while/Switch_1:1,while/Switch_1
+while/Identity,while/Identity
+while/Exit,while/Exit
+while/Identity_1,while/Identity_1
+while/Exit_1,while/Exit_1
+while/add,while/add
+while/NextIteration_1,while/NextIteration_1
+while/NextIteration,while/NextIteration
diff --git a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/linalg/BaseNd4jTestWithBackends.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/linalg/BaseNd4jTestWithBackends.java
index 1758ac8ec..44bd24556 100644
--- a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/linalg/BaseNd4jTestWithBackends.java
+++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/linalg/BaseNd4jTestWithBackends.java
@@ -53,8 +53,6 @@ public abstract class BaseNd4jTestWithBackends extends BaseND4JTest {
}
}
- protected Nd4jBackend backend;
- protected String name;
public final static String DEFAULT_BACKEND = "org.nd4j.linalg.defaultbackend";
@@ -95,7 +93,7 @@ public abstract class BaseNd4jTestWithBackends extends BaseND4JTest {
return 'c';
}
- public String getFailureMessage() {
+ public String getFailureMessage(Nd4jBackend backend) {
return "Failed with backend " + backend.getClass().getName() + " and ordering " + ordering();
}
}
diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml
index ab0fa3096..505ea85a0 100644
--- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml
+++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/pom.xml
@@ -85,60 +85,5 @@
nd4j-testresources
-
- nd4j-tests-cpu
-
- false
-
-
-
- org.nd4j
- nd4j-native
- ${project.version}
-
-
-
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
-
- src/test/java
-
- *.java
- **/*.java
-
- -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"
-
-
-
-
-
-
-
- nd4j-tests-cuda
-
- false
-
-
-
- org.nd4j
- nd4j-cuda-11.0
- ${project.version}
-
-
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
-
- true
-
-
-
-
-
diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml
index 6b0de214f..4eb2a05e2 100644
--- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml
+++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml
@@ -111,7 +111,7 @@
*.java
**/*.java
- "
+
diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml
index d24533025..147366e5e 100644
--- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml
+++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/pom.xml
@@ -103,7 +103,7 @@
*.java
**/*.java
- "
+
@@ -126,10 +126,6 @@
org.apache.maven.plugins
maven-surefire-plugin
-
- -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"
-
-
diff --git a/nd4j/nd4j-serde/nd4j-aeron/pom.xml b/nd4j/nd4j-serde/nd4j-aeron/pom.xml
index 68a75125b..8b86b6a9e 100644
--- a/nd4j/nd4j-serde/nd4j-aeron/pom.xml
+++ b/nd4j/nd4j-serde/nd4j-aeron/pom.xml
@@ -73,117 +73,5 @@
testresources
-
- nd4j-tests-cpu
-
- false
-
-
-
- org.nd4j
- nd4j-native
- ${project.version}
-
-
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
- true
-
-
- org.nd4j
- nd4j-native
- ${project.version}
-
-
-
-
-
-
- src/test/java
-
- *.java
- **/*.java
- **/Test*.java
- **/*Test.java
- **/*TestCase.java
-
- org.junit.jupiter:junit-jupiter
-
-
- org.nd4j.linalg.cpu.nativecpu.CpuBackend
-
-
- org.nd4j.linalg.cpu.nativecpu.CpuBackend
-
-
-
- "
-
-
-
-
-
-
- nd4j-tests-cuda
-
- false
-
-
-
- org.nd4j
- nd4j-cuda-11.0
- ${project.version}
-
-
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
-
-
- org.apache.maven.surefire
- surefire-junit47
- 2.19.1
-
-
-
-
-
- src/test/java
-
- *.java
- **/*.java
- **/Test*.java
- **/*Test.java
- **/*TestCase.java
-
- org.junit.jupiter:junit-jupiter
-
-
- org.nd4j.linalg.jcublas.JCublasBackend
-
-
- org.nd4j.linalg.jcublas.JCublasBackend
-
-
-
- -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"
-
-
-
-
-
diff --git a/nd4j/nd4j-serde/nd4j-arrow/pom.xml b/nd4j/nd4j-serde/nd4j-arrow/pom.xml
index e3e4d3439..89ddb39ee 100644
--- a/nd4j/nd4j-serde/nd4j-arrow/pom.xml
+++ b/nd4j/nd4j-serde/nd4j-arrow/pom.xml
@@ -57,114 +57,5 @@
testresources
-
- nd4j-tests-cpu
-
- false
-
-
-
- org.nd4j
- nd4j-native
- ${project.version}
-
-
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
-
-
-
- ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/
-
-
- src/test/java
-
- *.java
- **/*.java
- **/Test*.java
- **/*Test.java
- **/*TestCase.java
-
- org.junit.jupiter:junit-jupiter
-
-
- org.nd4j.linalg.cpu.nativecpu.CpuBackend
-
-
- org.nd4j.linalg.cpu.nativecpu.CpuBackend
-
-
-
- -Dfile.encoding=UTF-8 "
-
-
-
-
-
-
- nd4j-tests-cuda
-
- false
-
-
-
- org.nd4j
- nd4j-cuda-11.0
- ${project.version}
-
-
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
-
-
- org.apache.maven.surefire
- surefire-junit47
- 2.19.1
-
-
-
-
-
- ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/
-
-
- src/test/java
-
- *.java
- **/*.java
- **/Test*.java
- **/*Test.java
- **/*TestCase.java
-
- org.junit.jupiter:junit-jupiter
-
-
- org.nd4j.linalg.jcublas.JCublasBackend
-
-
- org.nd4j.linalg.jcublas.JCublasBackend
-
-
-
- -Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"
-
-
-
-
-
diff --git a/nd4j/nd4j-serde/nd4j-kryo/pom.xml b/nd4j/nd4j-serde/nd4j-kryo/pom.xml
index 4298f3016..e32c887e3 100644
--- a/nd4j/nd4j-serde/nd4j-kryo/pom.xml
+++ b/nd4j/nd4j-serde/nd4j-kryo/pom.xml
@@ -113,114 +113,5 @@
testresources
-
- nd4j-tests-cpu
-
- false
-
-
-
- org.nd4j
- nd4j-native
- ${project.version}
-
-
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
-
-
-
- ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/
-
-
- src/test/java
-
- *.java
- **/*.java
- **/Test*.java
- **/*Test.java
- **/*TestCase.java
-
- org.junit.jupiter:junit-jupiter
-
-
- org.nd4j.linalg.cpu.nativecpu.CpuBackend
-
-
- org.nd4j.linalg.cpu.nativecpu.CpuBackend
-
-
-
- "
-
-
-
-
-
-
- nd4j-tests-cuda
-
- false
-
-
-
- org.nd4j
- nd4j-cuda-11.0
- ${project.version}
-
-
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
-
-
- org.apache.maven.surefire
- surefire-junit47
- 2.19.1
-
-
-
-
-
- ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/
-
-
- src/test/java
-
- *.java
- **/*.java
- **/Test*.java
- **/*Test.java
- **/*TestCase.java
-
- org.junit.jupiter:junit-jupiter
-
-
- org.nd4j.linalg.jcublas.JCublasBackend
-
-
- org.nd4j.linalg.jcublas.JCublasBackend
-
-
-
- -Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"
-
-
-
-
-
diff --git a/pom.xml b/pom.xml
index 791d7a4bc..5a05d9b97 100644
--- a/pom.xml
+++ b/pom.xml
@@ -319,8 +319,7 @@
0.9.1
1.0.0
2.2.0
- 1.4.30
- 1.3
+ 1.4.31
@@ -473,6 +472,15 @@
${maven-surefire-plugin.version}
+
+
+ org.junit:junit
+ com.google.android:android
+
+
+ true
+ false
+
org.jetbrains.kotlin
@@ -491,12 +499,12 @@
org.jetbrains.kotlin
kotlin-maven-allopen
- 1.4.30-M1
+ ${kotlin.version}
org.jetbrains.kotlin
kotlin-maven-noarg
- 1.4.30-M1
+ ${kotlin.version}
diff --git a/python4j/python4j-numpy/pom.xml b/python4j/python4j-numpy/pom.xml
index 09cb57553..cf321494d 100644
--- a/python4j/python4j-numpy/pom.xml
+++ b/python4j/python4j-numpy/pom.xml
@@ -20,8 +20,8 @@
-->
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
4.0.0
@@ -56,122 +56,4 @@
1.0.0-SNAPSHOT
-
-
-
- test-nd4j-native
-
-
- org.nd4j
- nd4j-native
- ${nd4j.version}
- test
-
-
- org.deeplearning4j
- dl4j-test-resources
- ${nd4j.version}
- test
-
-
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
- true
-
-
- org.nd4j
- nd4j-native
- ${project.version}
-
-
-
-
-
-
- src/test/java
-
- *.java
- **/*.java
- **/Test*.java
- **/*Test.java
- **/*TestCase.java
-
- org.junit.jupiter:junit-jupiter
-
-
- org.nd4j.linalg.cpu.nativecpu.CpuBackend
-
-
- org.nd4j.linalg.cpu.nativecpu.CpuBackend
-
-
-
- "
-
-
-
-
-
-
-
- test-nd4j-cuda-11.0
-
-
- org.nd4j
- nd4j-cuda-11.0
- ${nd4j.version}
- test
-
-
- org.deeplearning4j
- dl4j-test-resources
- ${nd4j.version}
- test
-
-
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
- true
-
-
-
- src/test/java
-
- *.java
- **/*.java
- **/Test*.java
- **/*Test.java
- **/*TestCase.java
-
- org.junit.jupiter:junit-jupiter
-
-
- org.nd4j.linalg.jcublas.JCublasBackend
-
-
- org.nd4j.linalg.jcublas.JCublasBackend
-
-
-
- -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"
-
-
-
-
-
-
diff --git a/rl4j/pom.xml b/rl4j/pom.xml
index 3c3d247ea..b0eae38ca 100644
--- a/rl4j/pom.xml
+++ b/rl4j/pom.xml
@@ -90,7 +90,7 @@
${skipBackendChoice}
- test-nd4j-native,test-nd4j-cuda-11.0
+ nd4j-tests-cpu,nd4j-tests-cuda
false
@@ -99,24 +99,6 @@
-
- maven-surefire-plugin
- true
-
- -Ddtype=double -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"
-
-
- true
- false
-
-
com.lewisd
lint-maven-plugin
@@ -180,7 +162,7 @@
- test-nd4j-native
+ nd4j-tests-cpu
org.nd4j
@@ -191,7 +173,7 @@
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
org.nd4j
diff --git a/rl4j/rl4j-ale/pom.xml b/rl4j/rl4j-ale/pom.xml
index a07325886..bbfe9dbc6 100644
--- a/rl4j/rl4j-ale/pom.xml
+++ b/rl4j/rl4j-ale/pom.xml
@@ -50,10 +50,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/rl4j/rl4j-api/pom.xml b/rl4j/rl4j-api/pom.xml
index 2d1b34a4c..731617137 100644
--- a/rl4j/rl4j-api/pom.xml
+++ b/rl4j/rl4j-api/pom.xml
@@ -45,10 +45,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/rl4j/rl4j-core/pom.xml b/rl4j/rl4j-core/pom.xml
index eb63be1c8..f1d056cd2 100644
--- a/rl4j/rl4j-core/pom.xml
+++ b/rl4j/rl4j-core/pom.xml
@@ -138,10 +138,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/rl4j/rl4j-doom/pom.xml b/rl4j/rl4j-doom/pom.xml
index 367267336..1ac2939d0 100644
--- a/rl4j/rl4j-doom/pom.xml
+++ b/rl4j/rl4j-doom/pom.xml
@@ -45,10 +45,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/rl4j/rl4j-gym/pom.xml b/rl4j/rl4j-gym/pom.xml
index 250f0cb97..180237718 100644
--- a/rl4j/rl4j-gym/pom.xml
+++ b/rl4j/rl4j-gym/pom.xml
@@ -51,10 +51,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda
diff --git a/rl4j/rl4j-malmo/pom.xml b/rl4j/rl4j-malmo/pom.xml
index 821cf99f2..213bef813 100644
--- a/rl4j/rl4j-malmo/pom.xml
+++ b/rl4j/rl4j-malmo/pom.xml
@@ -57,10 +57,10 @@
- test-nd4j-native
+ nd4j-tests-cpu
- test-nd4j-cuda-11.0
+ nd4j-tests-cuda