Unify nd4j test profiles, get rid of old modules, fix more parameter issues with junit 5 tests

master
agibsonccc 2021-03-18 10:58:50 +09:00
parent e0077c38a9
commit ad4f47096c
135 changed files with 789 additions and 1911 deletions

View File

@ -31,7 +31,7 @@ jobs:
protoc --version protoc --version
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd .. cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
windows-x86_64: windows-x86_64:
runs-on: windows-2019 runs-on: windows-2019
@ -44,7 +44,7 @@ jobs:
run: | run: |
set "PATH=C:\msys64\usr\bin;%PATH%" set "PATH=C:\msys64\usr\bin;%PATH%"
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
@ -60,5 +60,5 @@ jobs:
run: | run: |
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test

View File

@ -31,7 +31,7 @@ jobs:
protoc --version protoc --version
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd .. cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
windows-x86_64: windows-x86_64:
runs-on: windows-2019 runs-on: windows-2019
@ -44,7 +44,7 @@ jobs:
run: | run: |
set "PATH=C:\msys64\usr\bin;%PATH%" set "PATH=C:\msys64\usr\bin;%PATH%"
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
@ -60,5 +60,5 @@ jobs:
run: | run: |
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test

View File

@ -34,5 +34,5 @@ jobs:
cmake --version cmake --version
protoc --version protoc --version
export OMP_NUM_THREADS=1 export OMP_NUM_THREADS=1
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Ptest-nd4j-native --also-make clean test mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Pnd4j-tests-cpu --also-make clean test

View File

@ -109,10 +109,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -30,6 +30,8 @@ import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.loader.FileBatch; import org.nd4j.common.loader.FileBatch;
import java.io.File; import java.io.File;
@ -40,13 +42,16 @@ import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path; import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.linalg.factory.Nd4jBackend;
@DisplayName("File Batch Record Reader Test") @DisplayName("File Batch Record Reader Test")
class FileBatchRecordReaderTest extends BaseND4JTest { public class FileBatchRecordReaderTest extends BaseND4JTest {
@TempDir Path testDir;
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Csv") @DisplayName("Test Csv")
void testCsv(@TempDir Path testDir) throws Exception { void testCsv(Nd4jBackend backend) throws Exception {
// This is an unrealistic use case - one line/record per CSV // This is an unrealistic use case - one line/record per CSV
File baseDir = testDir.toFile(); File baseDir = testDir.toFile();
List<File> fileList = new ArrayList<>(); List<File> fileList = new ArrayList<>();
@ -75,9 +80,10 @@ class FileBatchRecordReaderTest extends BaseND4JTest {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Csv Sequence") @DisplayName("Test Csv Sequence")
void testCsvSequence(@TempDir Path testDir) throws Exception { void testCsvSequence(Nd4jBackend backend) throws Exception {
// CSV sequence - 3 lines per file, 10 files // CSV sequence - 3 lines per file, 10 files
File baseDir = testDir.toFile(); File baseDir = testDir.toFile();
List<File> fileList = new ArrayList<>(); List<File> fileList = new ArrayList<>();

View File

@ -21,7 +21,6 @@ package org.datavec.api.transform.ops;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;

View File

@ -60,10 +60,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -119,10 +119,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -59,10 +59,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -57,10 +57,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -65,10 +65,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -61,25 +61,18 @@
<artifactId>nd4j-common</artifactId> <artifactId>nd4j-common</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.datavec</groupId> <groupId>org.nd4j</groupId>
<artifactId>datavec-geo</artifactId> <artifactId>python4j-numpy</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-python</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -29,7 +29,6 @@ import org.datavec.api.transform.reduce.Reducer;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.datavec.python.PythonTransform;
import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -39,7 +38,6 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static java.time.Duration.ofMillis; import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertTimeout; import static org.junit.jupiter.api.Assertions.assertTimeout;
@ -166,37 +164,8 @@ class ExecutionTest {
List<List<Writable>> out = outRdd; List<List<Writable>> out = outRdd;
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0))); List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
out = new ArrayList<>(out); out = new ArrayList<>(out);
Collections.sort(out, new Comparator<List<Writable>>() { Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
assertEquals(expOut, out); assertEquals(expOut, out);
} }
@Test
@Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
@DisplayName("Test Python Execution Ndarray")
void testPythonExecutionNdarray() {
assertTimeout(ofMillis(60000), () -> {
Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build();
TransformProcess transformProcess = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build();
List<List<Writable>> functions = new ArrayList<>();
List<Writable> firstRow = new ArrayList<>();
INDArray firstArr = Nd4j.linspace(1, 4, 4);
INDArray secondArr = Nd4j.linspace(1, 4, 4);
firstRow.add(new NDArrayWritable(firstArr));
firstRow.add(new NDArrayWritable(secondArr));
functions.add(firstRow);
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
INDArray expected = Transforms.sin(firstArr);
INDArray secondExpected = Transforms.cos(secondArr);
assertEquals(expected, firstResult);
assertEquals(secondExpected, secondResult);
});
}
} }

View File

@ -1,155 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.local.transforms.transform;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.geo.LocationType;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform;
import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform;
import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import java.io.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* @author saudet
*/
public class TestGeoTransforms {
@BeforeAll
public static void beforeClass() throws Exception {
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath());
}
@AfterClass
public static void afterClass(){
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, "");
}
@Test
public void testCoordinatesDistanceTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
.build();
Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(4, out.numColumns());
assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
out.getColumnTypes());
assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
new DoubleWritable(Math.sqrt(160))),
transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
new Text("10|5"))));
}
@Test
public void testIPAddressToCoordinatesTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("column").build();
Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER");
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(1, out.getColumnMetaData().size());
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
String in = "81.2.69.160";
double latitude = 51.5142;
double longitude = -0.0931;
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
assertEquals(2, coordinates.length);
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
//Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(transform);
byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
ObjectInputStream ois = new ObjectInputStream(bais);
Transform deserialized = (Transform) ois.readObject();
writables = deserialized.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
//System.out.println(Arrays.toString(coordinates));
assertEquals(2, coordinates.length);
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
}
@Test
public void testIPAddressToLocationTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("column").build();
LocationType[] locationTypes = LocationType.values();
String in = "81.2.69.160";
String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167",
"51.5142:-0.0931", "", "England", "6269131"}; //Note: no postcode in this test DB for this record
for (int i = 0; i < locationTypes.length; i++) {
LocationType locationType = locationTypes[i];
String location = locations[i];
Transform transform = new IPAddressToLocationTransform("column", locationType);
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(1, out.getColumnMetaData().size());
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
assertEquals(location, writables.get(0).toString());
//System.out.println(location);
}
}
}

View File

@ -1,386 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.local.transforms.transform;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.condition.Condition;
import org.datavec.api.transform.filter.ConditionFilter;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.schema.Schema;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.api.writable.*;
import org.datavec.python.PythonCondition;
import org.datavec.python.PythonTransform;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.datavec.api.transform.schema.Schema.Builder;
import static org.junit.jupiter.api.Assertions.*;
@NotThreadSafe
public class TestPythonTransformProcess {
@Test()
public void testStringConcat() throws Exception{
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnString("col1")
.addColumnString("col2");
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnString("col3");
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).build();
List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
List<Writable> outputs = tp.execute(inputs);
assertEquals((outputs.get(0)).toString(), "Hello ");
assertEquals((outputs.get(1)).toString(), "World!");
assertEquals((outputs.get(2)).toString(), "Hello World!");
}
@Test()
@Timeout(60000L)
public void testMixedTypes() throws Exception {
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")
.addColumnFloat("col2")
.addColumnString("col3")
.addColumnDouble("col4");
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnInteger("col5");
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.inputSchema(initialSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList(new IntWritable(10),
new FloatWritable(3.5f),
new Text("5"),
new DoubleWritable(2.0)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(((LongWritable)outputs.get(4)).get(), 36);
}
@Test()
@Timeout(60000L)
public void testNDArray() throws Exception {
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(shape);
INDArray arr2 = Nd4j.rand(shape);
INDArray expectedOutput = arr1.add(arr2);
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnNDArray("col3", shape);
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList(
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
}
@Test()
@Timeout(60000L)
public void testNDArray2() throws Exception {
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(shape);
INDArray arr2 = Nd4j.rand(shape);
INDArray expectedOutput = arr1.add(arr2);
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnNDArray("col3", shape);
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList(
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
}
@Test()
@Timeout(60000L)
public void testNDArrayMixed() throws Exception{
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnNDArray("col3", shape);
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).build();
List<Writable> inputs = Arrays.asList(
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
}
@Test()
@Timeout(60000L)
public void testPythonFilter() {
Schema schema = new Builder().addColumnInteger("column").build();
Condition condition = new PythonCondition(
"f = lambda: column < 0"
);
condition.setInputSchema(schema);
Filter filter = new ConditionFilter(condition);
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
}
@Test()
@Timeout(60000L)
public void testPythonFilterAndTransform() throws Exception {
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")
.addColumnFloat("col2")
.addColumnString("col3")
.addColumnDouble("col4");
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnString("col6");
Schema finalSchema = schemaBuilder.build();
Condition condition = new PythonCondition(
"f = lambda: col1 < 0 and col2 > 10.0"
);
condition.setInputSchema(initialSchema);
Filter filter = new ConditionFilter(condition);
String pythonCode = "col6 = str(col1 + col2)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).filter(
filter
).build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(
Arrays.asList(
(Writable)
new IntWritable(5),
new FloatWritable(3.0f),
new Text("abcd"),
new DoubleWritable(2.1))
);
inputs.add(
Arrays.asList(
(Writable)
new IntWritable(-3),
new FloatWritable(3.0f),
new Text("abcd"),
new DoubleWritable(2.1))
);
inputs.add(
Arrays.asList(
(Writable)
new IntWritable(5),
new FloatWritable(11.2f),
new Text("abcd"),
new DoubleWritable(2.1))
);
LocalTransformExecutor.execute(inputs,tp);
}
@Test
public void testPythonTransformNoOutputSpecified() throws Exception {
PythonTransform pythonTransform = PythonTransform.builder()
.code("a += 2; b = 'hello world'")
.returnAllInputs(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable)new IntWritable(1)));
Schema inputSchema = new Builder()
.addColumnInteger("a")
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertEquals(3,execute.get(0).get(0).toInt());
assertEquals("hello world",execute.get(0).get(1).toString());
}
@Test
public void testNumpyTransform() {
PythonTransform pythonTransform = PythonTransform.builder()
.code("a += 2; b = 'hello world'")
.returnAllInputs(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
Schema inputSchema = new Builder()
.addColumnNDArray("a",new long[]{1,1})
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertFalse(execute.isEmpty());
assertNotNull(execute.get(0));
assertNotNull(execute.get(0).get(0));
assertNotNull(execute.get(0).get(1));
assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get());
assertEquals("hello world",execute.get(0).get(1).toString());
}
@Test
public void testWithSetupRun() throws Exception {
PythonTransform pythonTransform = PythonTransform.builder()
.code("five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b + five\n"+
" return {'c':c}\n\n")
.returnAllInputs(true)
.setupAndRun(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)),
new NDArrayWritable(Nd4j.scalar(2).reshape(1,1))));
Schema inputSchema = new Builder()
.addColumnNDArray("a",new long[]{1,1})
.addColumnNDArray("b", new long[]{1, 1})
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertFalse(execute.isEmpty());
assertNotNull(execute.get(0));
assertNotNull(execute.get(0).get(0));
assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get());
}
}

View File

@ -128,10 +128,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -92,6 +92,10 @@
<groupId>org.junit.jupiter</groupId> <groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId> <artifactId>junit-jupiter-api</artifactId>
</dependency> </dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
</dependency>
<dependency> <dependency>
<groupId>org.junit.vintage</groupId> <groupId>org.junit.vintage</groupId>
<artifactId>junit-vintage-engine</artifactId> <artifactId>junit-vintage-engine</artifactId>
@ -154,7 +158,7 @@
<skip>${skipTestResourceEnforcement}</skip> <skip>${skipTestResourceEnforcement}</skip>
<rules> <rules>
<requireActiveProfile> <requireActiveProfile>
<profiles>test-nd4j-native,test-nd4j-cuda-11.0</profiles> <profiles>nd4j-tests-cpu,nd4j-tests-cuda</profiles>
<all>false</all> <all>false</all>
</requireActiveProfile> </requireActiveProfile>
</rules> </rules>
@ -163,23 +167,6 @@
</execution> </execution>
</executions> </executions>
</plugin> </plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<argLine></argLine>
<!--
By default: Surefire will set the classpath based on the manifest. Because tests are not included
in the JAR, any tests that rely on class path scanning for resources in the tests directory will not
function correctly without this configuration.
For example, tests for custom transforms (where the custom transform is defined in the test directory)
will fail due to the custom transform not being found on the classpath.
http://maven.apache.org/surefire/maven-surefire-plugin/examples/class-loading.html
-->
<useSystemClassLoader>true</useSystemClassLoader>
<useManifestOnlyJar>false</useManifestOnlyJar>
</configuration>
</plugin>
<plugin> <plugin>
<groupId>org.eclipse.m2e</groupId> <groupId>org.eclipse.m2e</groupId>
<artifactId>lifecycle-mapping</artifactId> <artifactId>lifecycle-mapping</artifactId>
@ -249,7 +236,7 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
@ -266,7 +253,7 @@
</dependencies> </dependencies>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
@ -286,9 +273,6 @@
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId> <artifactId>maven-surefire-plugin</artifactId>
<configuration>
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
</configuration>
</plugin> </plugin>
</plugins> </plugins>
</build> </build>

View File

@ -64,7 +64,7 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
@ -75,7 +75,7 @@
</dependencies> </dependencies>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>

View File

@ -56,10 +56,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -166,7 +166,7 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
@ -177,7 +177,7 @@
</dependencies> </dependencies>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>

View File

@ -23,7 +23,6 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -34,7 +33,6 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Early Termination Data Set Iterator Test") @DisplayName("Early Termination Data Set Iterator Test")
class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest { class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {

View File

@ -21,19 +21,16 @@ package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Early Termination Multi Data Set Iterator Test") @DisplayName("Early Termination Multi Data Set Iterator Test")

View File

@ -34,7 +34,6 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -46,7 +45,6 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Disabled @Disabled
@DisplayName("Attention Layer Test") @DisplayName("Attention Layer Test")

View File

@ -105,11 +105,12 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
<build> <build>
<plugins> <plugins>
<plugin> <plugin>
<artifactId>maven-surefire-plugin</artifactId> <artifactId>maven-surefire-plugin</artifactId>
<inherited>true</inherited>
<configuration> <configuration>
<skip>true</skip> <skip>true</skip>
</configuration> </configuration>
@ -118,7 +119,7 @@
</build> </build>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -56,10 +56,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -50,10 +50,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -45,10 +45,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -54,10 +54,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -112,7 +112,7 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
@ -123,7 +123,7 @@
</dependencies> </dependencies>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>

View File

@ -72,10 +72,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -306,7 +306,7 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
@ -317,7 +317,7 @@
</dependencies> </dependencies>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>

View File

@ -101,10 +101,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -49,10 +49,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -127,10 +127,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -102,7 +102,7 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
@ -113,7 +113,7 @@
</dependencies> </dependencies>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>

View File

@ -99,10 +99,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -44,10 +44,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -89,10 +89,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -88,10 +88,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -90,10 +90,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -105,10 +105,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -182,10 +182,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -77,10 +77,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -104,10 +104,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -141,10 +141,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -79,10 +79,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -426,10 +426,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -44,10 +44,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -87,10 +87,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -117,10 +117,10 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -143,6 +143,10 @@
</extensions> </extensions>
<plugins> <plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
</plugin>
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-enforcer-plugin</artifactId> <artifactId>maven-enforcer-plugin</artifactId>
@ -158,7 +162,7 @@
<skip>${skipBackendChoice}</skip> <skip>${skipBackendChoice}</skip>
<rules> <rules>
<requireActiveProfile> <requireActiveProfile>
<profiles>test-nd4j-native,test-nd4j-cuda-11.0</profiles> <profiles>nd4j-tests-cpu,nd4j-tests-cuda</profiles>
<all>false</all> <all>false</all>
</requireActiveProfile> </requireActiveProfile>
</rules> </rules>
@ -227,43 +231,6 @@
</plugin> </plugin>
</plugins> </plugins>
<pluginManagement>
<plugins>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<inherited>true</inherited>
<configuration>
<!--
By default: Surefire will set the classpath based on the manifest. Because tests are not included
in the JAR, any tests that rely on class path scanning for resources in the tests directory will not
function correctly without this configuration.
For example, tests for custom layers (where the custom layer is defined in the test directory)
will fail due to the custom layer not being found on the classpath.
http://maven.apache.org/surefire/maven-surefire-plugin/examples/class-loading.html
-->
<useSystemClassLoader>true</useSystemClassLoader>
<useManifestOnlyJar>false</useManifestOnlyJar>
<argLine> -Dfile.encoding=UTF-8 -Xmx8g "</argLine>
<includes>
<!-- Default setting only runs tests that start/end with "Test" -->
<include>*.java</include>
<include>**/*.java</include>
</includes>
</configuration>
<dependencies>
<dependency>
<groupId>org.apache.maven.surefire</groupId>
<artifactId>surefire-junit-platform</artifactId>
<version>${maven-surefire-plugin.version}</version>
</dependency>
</dependencies>
</plugin>
<plugin>
<groupId>org.eclipse.m2e</groupId>
<artifactId>lifecycle-mapping</artifactId>
</plugin>
</plugins>
</pluginManagement>
</build> </build>
<profiles> <profiles>
@ -290,10 +257,10 @@
<module>deeplearning4j-cuda</module> <module>deeplearning4j-cuda</module>
</modules> </modules>
</profile> </profile>
<!-- For running unit tests with nd4j-native: "mvn clean test -P test-nd4j-native" <!-- For running unit tests with nd4j-native: "mvn clean test -P nd4j-tests-cpu"
Note that this excludes DL4J-cuda --> Note that this excludes DL4J-cuda -->
<profile> <profile>
<id>test-nd4j-native</id> <id>nd4j-tests-cpu</id>
<activation> <activation>
<activeByDefault>false</activeByDefault> <activeByDefault>false</activeByDefault>
</activation> </activation>
@ -311,70 +278,10 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<inherited>true</inherited>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<version>${junit.version}</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<version>${junit.version}</version>
</dependency>
<dependency>
<groupId>org.apache.maven.surefire</groupId>
<artifactId>surefire-junit-platform</artifactId>
<version>${maven-surefire-plugin.version}</version>
</dependency>
</dependencies>
<configuration>
<environmentVariables>
</environmentVariables>
<testSourceDirectory>src/test/java</testSourceDirectory>
<includes>
<include>*.java</include>
<include>**/*.java</include>
<include>**/Test*.java</include>
<include>**/*Test.java</include>
<include>**/*TestCase.java</include>
</includes>
<junitArtifactName>org.junit.jupiter:junit-jupiter-engine</junitArtifactName>
<systemPropertyVariables>
<org.nd4j.linalg.defaultbackend>
org.nd4j.linalg.cpu.nativecpu.CpuBackend
</org.nd4j.linalg.defaultbackend>
<org.nd4j.linalg.tests.backendstorun>
org.nd4j.linalg.cpu.nativecpu.CpuBackend
</org.nd4j.linalg.tests.backendstorun>
</systemPropertyVariables>
<!--
Maximum heap size was set to 8g, as a minimum required value for tests run.
Depending on a build machine, default value is not always enough.
For testing large zoo models, this may not be enough (so comment it out).
-->
<argLine></argLine>
</configuration>
</plugin>
</plugins>
</build>
</profile> </profile>
<!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" --> <!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" -->
<profile> <profile>
<id>test-nd4j-cuda-11.0</id> <id>nd4j-tests-cuda</id>
<activation> <activation>
<activeByDefault>false</activeByDefault> <activeByDefault>false</activeByDefault>
</activation> </activation>
@ -392,43 +299,6 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>
<!-- Default to ALL modules here, unlike nd4j-native -->
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>${maven-surefire-plugin.version}</version>
<configuration>
<environmentVariables>
</environmentVariables>
<testSourceDirectory>src/test/java</testSourceDirectory>
<includes>
<include>*.java</include>
<include>**/*.java</include>
<include>**/Test*.java</include>
<include>**/*Test.java</include>
<include>**/*TestCase.java</include>
</includes>
<junitArtifactName>org.junit.jupiter:junit-jupiter</junitArtifactName>
<systemPropertyVariables>
<org.nd4j.linalg.defaultbackend>
org.nd4j.linalg.jcublas.JCublasBackend
</org.nd4j.linalg.defaultbackend>
<org.nd4j.linalg.tests.backendstorun>
org.nd4j.linalg.jcublas.JCublasBackend
</org.nd4j.linalg.tests.backendstorun>
</systemPropertyVariables>
<!--
Maximum heap size was set to 6g, as a minimum required value for tests run.
Depending on a build machine, default value is not always enough.
-->
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
</configuration>
</plugin>
</plugins>
</build>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -5,7 +5,7 @@ Linux
[INFO] Total time: 14.610 s [INFO] Total time: 14.610 s
[INFO] Finished at: 2021-03-06T15:35:28+09:00 [INFO] Finished at: 2021-03-06T15:35:28+09:00
[INFO] ------------------------------------------------------------------------ [INFO] ------------------------------------------------------------------------
[WARNING] The requested profile "test-nd4j-native" could not be activated because it does not exist. [WARNING] The requested profile "nd4j-tests-cpu" could not be activated because it does not exist.
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1] [ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
[ERROR] [ERROR]
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch. [ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
@ -749,7 +749,7 @@ make[1]: Leaving directory '/c/Users/agibs/Documents/GitHub/eclipse-deeplearning
[INFO] Total time: 15.482 s [INFO] Total time: 15.482 s
[INFO] Finished at: 2021-03-06T15:27:35+09:00 [INFO] Finished at: 2021-03-06T15:27:35+09:00
[INFO] ------------------------------------------------------------------------ [INFO] ------------------------------------------------------------------------
[WARNING] The requested profile "test-nd4j-native" could not be activated because it does not exist. [WARNING] The requested profile "nd4j-tests-cpu" could not be activated because it does not exist.
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1] [ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
[ERROR] [ERROR]
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch. [ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.

View File

@ -0,0 +1,19 @@
Const,in_0
Const,while/Const
Const,while/add/y
Identity,in_0/read
Enter,while/Enter
Enter,while/Enter_1
Merge,while/Merge
Merge,while/Merge_1
Less,while/Less
LoopCond,while/LoopCond
Switch,while/Switch
Switch,while/Switch_1
Identity,while/Identity
Exit,while/Exit
Identity,while/Identity_1
Exit,while/Exit_1
Add,while/add
NextIteration,while/NextIteration_1
NextIteration,while/NextIteration

View File

@ -0,0 +1,16 @@
Identity,in_0/read
Enter,while/Enter
Enter,while/Enter_1
Merge,while/Merge
Merge,while/Merge_1
Less,while/Less
LoopCond,while/LoopCond
Switch,while/Switch
Switch,while/Switch_1
Identity,while/Identity
Exit,while/Exit
Identity,while/Identity_1
Exit,while/Exit_1
Add,while/add
NextIteration,while/NextIteration_1
NextIteration,while/NextIteration

View File

@ -0,0 +1,19 @@
in_0
while/Const
while/add/y
in_0/read
while/Enter
while/Enter_1
while/Merge
while/Merge_1
while/Less
while/LoopCond
while/Switch
while/Switch_1
while/Identity
while/Exit
while/Identity_1
while/Exit_1
while/add
while/NextIteration_1
while/NextIteration

View File

@ -303,7 +303,7 @@
For testing large zoo models, this may not be enough (so comment it out). For testing large zoo models, this may not be enough (so comment it out).
--> -->
<argLine>-Dfile.encoding=UTF-8 "</argLine> <argLine>-Dfile.encoding=UTF-8 </argLine>
</configuration> </configuration>
</plugin> </plugin>
</plugins> </plugins>

View File

@ -27,6 +27,7 @@ import java.util.List;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -482,7 +483,7 @@ public class LayerOpValidation extends BaseOpValidation {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConv3d(Nd4jBackend backend) { public void testConv3d(Nd4jBackend backend, TestInfo testInfo) {
//Pooling3d, Conv3D, batch norm //Pooling3d, Conv3D, batch norm
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -573,7 +574,7 @@ public class LayerOpValidation extends BaseOpValidation {
tc.testName(msg); tc.testName(msg);
String error = OpValidation.validate(tc); String error = OpValidation.validate(tc);
if (error != null) { if (error != null) {
failed.add(name); failed.add(testInfo.getTestMethod().get().getName());
} }
} }
} }
@ -1353,7 +1354,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err, err); assertNull(err, err);
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) { public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
int nIn = 3; int nIn = 3;
@ -1382,7 +1384,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) { public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1405,7 +1408,8 @@ public class LayerOpValidation extends BaseOpValidation {
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) { public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -22,6 +22,7 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
@ -664,6 +665,7 @@ public class MiscOpValidation extends BaseOpValidation {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testMmulGradientManual(Nd4jBackend backend) { public void testMmulGradientManual(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);

View File

@ -69,7 +69,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@AfterEach @AfterEach
public void tearDown(Nd4jBackend backend) { public void tearDown() {
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
} }

View File

@ -28,6 +28,7 @@ import lombok.val;
import org.apache.commons.math3.linear.LUDecomposition; import org.apache.commons.math3.linear.LUDecomposition;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite; import org.nd4j.OpValidationSuite;
@ -83,7 +84,7 @@ public class ShapeOpValidation extends BaseOpValidation {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConcat(Nd4jBackend backend) { public void testConcat(Nd4jBackend backend, TestInfo testInfo) {
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2}; // int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
int[] concatDim = new int[]{0, 0, 0}; int[] concatDim = new int[]{0, 0, 0};
List<List<int[]>> origShapes = new ArrayList<>(); List<List<int[]>> origShapes = new ArrayList<>();
@ -115,7 +116,7 @@ public class ShapeOpValidation extends BaseOpValidation {
String error = OpValidation.validate(tc); String error = OpValidation.validate(tc);
if(error != null){ if(error != null){
failed.add(name); failed.add(testInfo.getTestMethod().get().getName());
} }
} }
@ -285,7 +286,7 @@ public class ShapeOpValidation extends BaseOpValidation {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSqueezeGradient(Nd4jBackend backend) { public void testSqueezeGradient(Nd4jBackend backend,TestInfo testInfo) {
val origShape = new long[]{3, 4, 5}; val origShape = new long[]{3, 4, 5};
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -339,7 +340,7 @@ public class ShapeOpValidation extends BaseOpValidation {
String error = OpValidation.validate(tc, true); String error = OpValidation.validate(tc, true);
if(error != null){ if(error != null){
failed.add(name); failed.add(testInfo.getTestMethod().get().getName());
} }
} }
} }
@ -580,8 +581,9 @@ public class ShapeOpValidation extends BaseOpValidation {
return Long.MAX_VALUE; return Long.MAX_VALUE;
} }
@Test() @ParameterizedTest
public void testStack(Nd4jBackend backend) { @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStack(Nd4jBackend backend,TestInfo testInfo) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>(); List<String> failed = new ArrayList<>();
@ -661,7 +663,7 @@ public class ShapeOpValidation extends BaseOpValidation {
String error = OpValidation.validate(tc); String error = OpValidation.validate(tc);
if(error != null){ if(error != null){
failed.add(name); failed.add(testInfo.getTestMethod().get().getName());
} }
} }
} }

View File

@ -72,6 +72,8 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends { public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@Override @Override
public char ordering(){ public char ordering(){
@ -82,7 +84,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testBasic(Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4); INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() ); SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() );
@ -121,7 +123,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
int numOutputs = fg.outputsLength(); int numOutputs = fg.outputsLength();
List<IntPair> outputs = new ArrayList<>(numOutputs); List<IntPair> outputs = new ArrayList<>(numOutputs);
for( int i=0; i<numOutputs; i++ ){ for( int i = 0; i < numOutputs; i++) {
outputs.add(fg.outputs(i)); outputs.add(fg.outputs(i));
} }
@ -138,7 +140,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testSimple(Nd4jBackend backend) throws Exception {
for( int i = 0; i < 10; i++ ) { for( int i = 0; i < 10; i++ ) {
for(boolean execFirst : new boolean[]{false, true}) { for(boolean execFirst : new boolean[]{false, true}) {
log.info("Starting test: i={}, execFirst={}", i, execFirst); log.info("Starting test: i={}, execFirst={}", i, execFirst);
@ -268,7 +270,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testTrainingSerde(Nd4jBackend backend) throws Exception {
//Ensure 2 things: //Ensure 2 things:
//1. Training config is serialized/deserialized correctly //1. Training config is serialized/deserialized correctly

View File

@ -109,7 +109,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
} }
@BeforeEach @BeforeEach
public void before(Nd4jBackend backend) { public void before() {
Nd4j.create(1); Nd4j.create(1);
initialType = Nd4j.dataType(); initialType = Nd4j.dataType();
@ -118,7 +118,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
} }
@AfterEach @AfterEach
public void after(Nd4jBackend backend) { public void after() {
Nd4j.setDataType(initialType); Nd4j.setDataType(initialType);
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);

View File

@ -21,9 +21,6 @@
package org.nd4j.autodiff.samediff.listeners; package org.nd4j.autodiff.samediff.listeners;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
@ -47,11 +44,11 @@ import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class CheckpointListenerTest extends BaseNd4jTestWithBackends { public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@Override @Override
@ -97,7 +94,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testCheckpointEveryEpoch(Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();
@ -132,7 +129,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testCheckpointEvery5Iter(Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();
@ -172,7 +169,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testCheckpointListenerEveryTimeUnit(Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();
@ -217,7 +214,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testCheckpointListenerKeepLast3AndEvery3(Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
SameDiff sd = getModel(); SameDiff sd = getModel();

View File

@ -23,6 +23,7 @@ package org.nd4j.autodiff.samediff.listeners;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
@ -49,6 +50,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
public class ProfilingListenerTest extends BaseNd4jTestWithBackends { public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@Override @Override
public char ordering() { public char ordering() {
@ -59,7 +61,8 @@ public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception { @Disabled
public void testProfilingListenerSimple(Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);

View File

@ -64,6 +64,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j @Slf4j
public class FileReadWriteTests extends BaseNd4jTestWithBackends { public class FileReadWriteTests extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@Override @Override
public char ordering(){ public char ordering(){
@ -81,7 +82,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException { public void testSimple(Nd4jBackend backend) throws IOException {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4); SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
SDVariable sum = v.sum(); SDVariable sum = v.sum();
@ -163,7 +164,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
//Append a number of events //Append a number of events
w.registerEventName("accuracy"); w.registerEventName("accuracy");
for( int iter=0; iter<3; iter++) { for( int iter = 0; iter < 3; iter++) {
long t = System.currentTimeMillis(); long t = System.currentTimeMillis();
w.writeScalarEvent("accuracy", LogFileWriter.EventSubtype.EVALUATION, t, iter, 0, 0.5 + 0.1 * iter); w.writeScalarEvent("accuracy", LogFileWriter.EventSubtype.EVALUATION, t, iter, 0, 0.5 + 0.1 * iter);
} }
@ -175,7 +176,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
UIAddName addName = (UIAddName) events.get(0).getRight(); UIAddName addName = (UIAddName) events.get(0).getRight();
assertEquals("accuracy", addName.name()); assertEquals("accuracy", addName.name());
for( int i=1; i<4; i++ ){ for( int i = 1; i < 4; i++ ){
FlatArray fa = (FlatArray) events.get(i).getRight(); FlatArray fa = (FlatArray) events.get(i).getRight();
INDArray arr = Nd4j.createFromFlatArray(fa); INDArray arr = Nd4j.createFromFlatArray(fa);
@ -186,7 +187,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{ public void testNullBinLabels(Nd4jBackend backend) throws Exception{
File dir = testDir.toFile(); File dir = testDir.toFile();
File f = new File(dir, "temp.bin"); File f = new File(dir, "temp.bin");
LogFileWriter w = new LogFileWriter(f); LogFileWriter w = new LogFileWriter(f);

View File

@ -56,6 +56,8 @@ import static org.junit.jupiter.api.Assertions.*;
public class UIListenerTest extends BaseNd4jTestWithBackends { public class UIListenerTest extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';
@ -65,7 +67,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testUIListenerBasic(Nd4jBackend backend) throws Exception {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -102,7 +104,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testUIListenerContinue(Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
SameDiff sd1 = getSimpleNet(); SameDiff sd1 = getSimpleNet();
@ -194,7 +196,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testUIListenerBadContinue(Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150); IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
SameDiff sd1 = getSimpleNet(); SameDiff sd1 = getSimpleNet();
@ -275,7 +277,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
} }
private static SameDiff getSimpleNet(){ private static SameDiff getSimpleNet() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);

View File

@ -22,6 +22,7 @@ package org.nd4j.evaluation;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
@ -48,6 +49,7 @@ public class NewInstanceTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testNewInstances(Nd4jBackend backend) { public void testNewInstances(Nd4jBackend backend) {
boolean print = true; boolean print = true;
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -20,7 +20,7 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROC;
@ -42,14 +42,15 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class ROCBinaryTest extends BaseNd4jTestWithBackends { public class ROCBinaryTest extends BaseNd4jTestWithBackends {
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testROCBinary(Nd4jBackend backend) { public void testROCBinary(Nd4jBackend backend) {
//Compare ROCBinary to ROC class //Compare ROCBinary to ROC class
@ -144,7 +145,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocBinaryMerging(Nd4jBackend backend) { public void testRocBinaryMerging(Nd4jBackend backend) {
for (int nSteps : new int[]{30, 0}) { //0 == exact for (int nSteps : new int[]{30, 0}) { //0 == exact
@ -175,7 +176,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinaryPerOutputMasking(Nd4jBackend backend) { public void testROCBinaryPerOutputMasking(Nd4jBackend backend) {
@ -216,7 +217,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinary3d(Nd4jBackend backend) { public void testROCBinary3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -251,7 +252,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinary4d(Nd4jBackend backend) { public void testROCBinary4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -286,7 +287,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinary3dMasking(Nd4jBackend backend) { public void testROCBinary3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -348,7 +349,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinary4dMasking(Nd4jBackend backend) { public void testROCBinary4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -20,6 +20,7 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
@ -82,16 +83,16 @@ public class ROCTest extends BaseNd4jTestWithBackends {
expFPR.put(10 / 10.0, 0.0 / totalNegatives); expFPR.put(10 / 10.0, 0.0 / totalNegatives);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocBasic(Nd4jBackend backend) { public void testRocBasic(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax) //2 outputs here - probability distribution over classes (softmax)
INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601}, {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}}); {0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
INDArray actual = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1}, INDArray actual = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
{0, 1}, {0, 1}}); {0, 1}, {0, 1}});
ROC roc = new ROC(10); ROC roc = new ROC(10);
roc.eval(actual, predictions); roc.eval(actual, predictions);
@ -126,15 +127,15 @@ public class ROCTest extends BaseNd4jTestWithBackends {
assertEquals(1.0, auc, 1e-6); assertEquals(1.0, auc, 1e-6);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocBasicSingleClass(Nd4jBackend backend) { public void testRocBasicSingleClass(Nd4jBackend backend) {
//1 output here - single probability value (sigmoid) //1 output here - single probability value (sigmoid)
//add 0.001 to avoid numerical/rounding issues (float vs. double, etc) //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
INDArray predictions = INDArray predictions =
Nd4j.create(new double[] {0.001, 0.101, 0.201, 0.301, 0.401, 0.501, 0.601, 0.701, 0.801, 0.901}, Nd4j.create(new double[] {0.001, 0.101, 0.201, 0.301, 0.401, 0.501, 0.601, 0.701, 0.801, 0.901},
new int[] {10, 1}); new int[] {10, 1});
INDArray actual = Nd4j.create(new double[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}, new int[] {10, 1}); INDArray actual = Nd4j.create(new double[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}, new int[] {10, 1});
@ -165,7 +166,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRoc(Nd4jBackend backend) { public void testRoc(Nd4jBackend backend) {
//Previous tests allowed for a perfect classifier with right threshold... //Previous tests allowed for a perfect classifier with right threshold...
@ -173,7 +174,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}}); INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}});
INDArray prediction = Nd4j.create(new double[][] {{0.199, 0.801}, {0.499, 0.501}, {0.399, 0.601}, INDArray prediction = Nd4j.create(new double[][] {{0.199, 0.801}, {0.499, 0.501}, {0.399, 0.601},
{0.799, 0.201}, {0.899, 0.101}}); {0.799, 0.201}, {0.899, 0.101}});
Map<Double, Double> expTPR = new HashMap<>(); Map<Double, Double> expTPR = new HashMap<>();
double totalPositives = 2.0; double totalPositives = 2.0;
@ -251,27 +252,27 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) { public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
//Same as first test... //Same as first test...
//2 outputs here - probability distribution over classes (softmax) //2 outputs here - probability distribution over classes (softmax)
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601}, {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}}); {0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1}, INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
{0, 1}, {0, 1}}); {0, 1}, {0, 1}});
INDArray predictions3d = Nd4j.create(2, 2, 5); INDArray predictions3d = Nd4j.create(2, 2, 5);
INDArray firstTSp = INDArray firstTSp =
predictions3d.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()).transpose(); predictions3d.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
assertArrayEquals(new long[] {5, 2}, firstTSp.shape()); assertArrayEquals(new long[] {5, 2}, firstTSp.shape());
firstTSp.assign(predictions2d.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all())); firstTSp.assign(predictions2d.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all()));
INDArray secondTSp = INDArray secondTSp =
predictions3d.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()).transpose(); predictions3d.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
assertArrayEquals(new long[] {5, 2}, secondTSp.shape()); assertArrayEquals(new long[] {5, 2}, secondTSp.shape());
secondTSp.assign(predictions2d.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all())); secondTSp.assign(predictions2d.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all()));
@ -299,23 +300,23 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocTimeSeriesMasking(Nd4jBackend backend) { public void testRocTimeSeriesMasking(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax) //2 outputs here - probability distribution over classes (softmax)
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc) INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601}, {0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}}); {0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1}, INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
{0, 1}, {0, 1}}); {0, 1}, {0, 1}});
//Create time series data... first time series: length 4. Second time series: length 6 //Create time series data... first time series: length 4. Second time series: length 6
INDArray predictions3d = Nd4j.create(2, 2, 6); INDArray predictions3d = Nd4j.create(2, 2, 6);
INDArray tad = predictions3d.tensorAlongDimension(0, 1, 2).transpose(); INDArray tad = predictions3d.tensorAlongDimension(0, 1, 2).transpose();
tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()) tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
.assign(predictions2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())); .assign(predictions2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
tad = predictions3d.tensorAlongDimension(1, 1, 2).transpose(); tad = predictions3d.tensorAlongDimension(1, 1, 2).transpose();
tad.assign(predictions2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all())); tad.assign(predictions2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));
@ -324,7 +325,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
INDArray labels3d = Nd4j.create(2, 2, 6); INDArray labels3d = Nd4j.create(2, 2, 6);
tad = labels3d.tensorAlongDimension(0, 1, 2).transpose(); tad = labels3d.tensorAlongDimension(0, 1, 2).transpose();
tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()) tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
.assign(actual2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())); .assign(actual2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
tad = labels3d.tensorAlongDimension(1, 1, 2).transpose(); tad = labels3d.tensorAlongDimension(1, 1, 2).transpose();
tad.assign(actual2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all())); tad.assign(actual2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));
@ -350,7 +351,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) { public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -381,7 +382,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCompare2Vs3Classes(Nd4jBackend backend) { public void testCompare2Vs3Classes(Nd4jBackend backend) {
@ -431,7 +432,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCMerging(Nd4jBackend backend) { public void testROCMerging(Nd4jBackend backend) {
int nArrays = 10; int nArrays = 10;
@ -477,7 +478,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCMerging2(Nd4jBackend backend) { public void testROCMerging2(Nd4jBackend backend) {
int nArrays = 10; int nArrays = 10;
@ -523,7 +524,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCMultiMerging(Nd4jBackend backend) { public void testROCMultiMerging(Nd4jBackend backend) {
@ -572,7 +573,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAUCPrecisionRecall(Nd4jBackend backend) { public void testAUCPrecisionRecall(Nd4jBackend backend) {
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob //Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
@ -620,7 +621,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocAucExact(Nd4jBackend backend) { public void testRocAucExact(Nd4jBackend backend) {
@ -681,20 +682,20 @@ public class ROCTest extends BaseNd4jTestWithBackends {
*/ */
double[] p = new double[] {0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503, 0.5955447, 0.96451452, double[] p = new double[] {0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503, 0.5955447, 0.96451452,
0.6531771, 0.74890664, 0.65356987, 0.74771481, 0.96130674, 0.0083883, 0.10644438, 0.29870371, 0.6531771, 0.74890664, 0.65356987, 0.74771481, 0.96130674, 0.0083883, 0.10644438, 0.29870371,
0.65641118, 0.80981255, 0.87217591, 0.9646476, 0.72368535, 0.64247533, 0.71745362, 0.46759901, 0.65641118, 0.80981255, 0.87217591, 0.9646476, 0.72368535, 0.64247533, 0.71745362, 0.46759901,
0.32558468, 0.43964461, 0.72968908, 0.99401459, 0.67687371, 0.79082252, 0.17091426}; 0.32558468, 0.43964461, 0.72968908, 0.99401459, 0.67687371, 0.79082252, 0.17091426};
double[] l = new double[] {1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, double[] l = new double[] {1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
0, 1}; 0, 1};
double[] fpr_skl = new double[] {0.0, 0.0, 0.15789474, 0.15789474, 0.31578947, 0.31578947, 0.52631579, double[] fpr_skl = new double[] {0.0, 0.0, 0.15789474, 0.15789474, 0.31578947, 0.31578947, 0.52631579,
0.52631579, 0.68421053, 0.68421053, 0.84210526, 0.84210526, 0.89473684, 0.89473684, 1.0}; 0.52631579, 0.68421053, 0.68421053, 0.84210526, 0.84210526, 0.89473684, 0.89473684, 1.0};
double[] tpr_skl = new double[] {0.0, 0.09090909, 0.09090909, 0.18181818, 0.18181818, 0.36363636, 0.36363636, double[] tpr_skl = new double[] {0.0, 0.09090909, 0.09090909, 0.18181818, 0.18181818, 0.36363636, 0.36363636,
0.45454545, 0.45454545, 0.72727273, 0.72727273, 0.90909091, 0.90909091, 1.0, 1.0}; 0.45454545, 0.45454545, 0.72727273, 0.72727273, 0.90909091, 0.90909091, 1.0, 1.0};
//Note the change to the last value: same TPR and FPR at 0.0083883 and 0.0 -> we add the 0.0 threshold edge case + combine with the previous one. Same result //Note the change to the last value: same TPR and FPR at 0.0083883 and 0.0 -> we add the 0.0 threshold edge case + combine with the previous one. Same result
double[] thr_skl = new double[] {1.0, 0.99401459, 0.96130674, 0.92961609, 0.79082252, 0.74771481, 0.67687371, double[] thr_skl = new double[] {1.0, 0.99401459, 0.96130674, 0.92961609, 0.79082252, 0.74771481, 0.67687371,
0.65641118, 0.64247533, 0.46759901, 0.31637555, 0.20456028, 0.18391881, 0.17091426, 0.0}; 0.65641118, 0.64247533, 0.46759901, 0.31637555, 0.20456028, 0.18391881, 0.17091426, 0.0};
INDArray prob = Nd4j.create(p, new int[] {30, 1}); INDArray prob = Nd4j.create(p, new int[] {30, 1});
INDArray label = Nd4j.create(l, new int[] {30, 1}); INDArray label = Nd4j.create(l, new int[] {30, 1});
@ -784,7 +785,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) { public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
@ -797,7 +798,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) { public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
double[] threshold = new double[101]; double[] threshold = new double[101];
@ -814,15 +815,15 @@ public class ROCTest extends BaseNd4jTestWithBackends {
PrecisionRecallCurve prc = new PrecisionRecallCurve(threshold, precision, recall, null, null, null, -1); PrecisionRecallCurve prc = new PrecisionRecallCurve(threshold, precision, recall, null, null, null, -1);
PrecisionRecallCurve.Point[] points = new PrecisionRecallCurve.Point[] { PrecisionRecallCurve.Point[] points = new PrecisionRecallCurve.Point[] {
//Test exact: //Test exact:
prc.getPointAtThreshold(0.05), prc.getPointAtPrecision(0.05), prc.getPointAtRecall(1 - 0.05), prc.getPointAtThreshold(0.05), prc.getPointAtPrecision(0.05), prc.getPointAtRecall(1 - 0.05),
//Test approximate (point doesn't exist exactly). When it doesn't exist: //Test approximate (point doesn't exist exactly). When it doesn't exist:
//Threshold: lowest threshold equal to or exceeding the specified threshold value //Threshold: lowest threshold equal to or exceeding the specified threshold value
//Precision: lowest threshold equal to or exceeding the specified precision value //Precision: lowest threshold equal to or exceeding the specified precision value
//Recall: highest threshold equal to or exceeding the specified recall value //Recall: highest threshold equal to or exceeding the specified recall value
prc.getPointAtThreshold(0.0495), prc.getPointAtPrecision(0.0495), prc.getPointAtThreshold(0.0495), prc.getPointAtPrecision(0.0495),
prc.getPointAtRecall(1 - 0.0505)}; prc.getPointAtRecall(1 - 0.0505)};
@ -834,7 +835,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) { public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
//Sanity check: values calculated from the confusion matrix should match the PR curve values //Sanity check: values calculated from the confusion matrix should match the PR curve values
@ -843,7 +844,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
ROC r = new ROC(0, removeRedundantPts); ROC r = new ROC(0, removeRedundantPts);
INDArray labels = Nd4j.getExecutioner() INDArray labels = Nd4j.getExecutioner()
.exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE,100, 1), 0.5)); .exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE,100, 1), 0.5));
INDArray probs = Nd4j.rand(100, 1); INDArray probs = Nd4j.rand(100, 1);
r.eval(labels, probs); r.eval(labels, probs);
@ -874,7 +875,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocMerge(){ public void testRocMerge(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -919,7 +920,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
assertEquals(auprc, auprcAct, 1e-6); assertEquals(auprc, auprcAct, 1e-6);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocMultiMerge(){ public void testRocMultiMerge(){
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -931,9 +932,9 @@ public class ROCTest extends BaseNd4jTestWithBackends {
int nOut = 5; int nOut = 5;
Random r = new Random(12345); Random r = new Random(12345);
for( int i=0; i<10; i++ ){ for( int i = 0; i < 10; i++ ){
INDArray labels = Nd4j.zeros(3, nOut); INDArray labels = Nd4j.zeros(3, nOut);
for( int j=0; j<3; j++ ){ for( int j = 0; j < 3; j++) {
labels.putScalar(j, r.nextInt(nOut), 1.0 ); labels.putScalar(j, r.nextInt(nOut), 1.0 );
} }
INDArray out = Nd4j.rand(3, nOut); INDArray out = Nd4j.rand(3, nOut);
@ -956,7 +957,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
roc1.merge(roc2); roc1.merge(roc2);
for( int i=0; i<nOut; i++ ) { for( int i = 0; i < nOut; i++) {
double aucExp = roc.calculateAUC(i); double aucExp = roc.calculateAUC(i);
double auprc = roc.calculateAUCPR(i); double auprc = roc.calculateAUCPR(i);
@ -969,9 +970,10 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocBinaryMerge(){ @Disabled
public void testRocBinaryMerge(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
ROCBinary roc = new ROCBinary(); ROCBinary roc = new ROCBinary();
@ -980,7 +982,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
int nOut = 5; int nOut = 5;
for( int i=0; i<10; i++ ){ for( int i = 0; i < 10; i++) {
INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(3, nOut),0.5)); INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(3, nOut),0.5));
INDArray out = Nd4j.rand(3, nOut); INDArray out = Nd4j.rand(3, nOut);
out.diviColumnVector(out.sum(1)); out.diviColumnVector(out.sum(1));
@ -1015,7 +1017,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSegmentationBinary(){ public void testSegmentationBinary(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
@ -1106,7 +1108,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSegmentation(){ public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case

View File

@ -50,7 +50,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvalParameters(Nd4jBackend backend) { public void testEvalParameters(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
int specCols = 5; int specCols = 5;

View File

@ -152,7 +152,7 @@ public class LoneTest extends BaseNd4jTestWithBackends {
public void maskWhenMerge(Nd4jBackend backend) { public void maskWhenMerge(Nd4jBackend backend) {
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5)); DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3)); DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
List<DataSet> dataSetList = new ArrayList<DataSet>(); List<DataSet> dataSetList = new ArrayList<>();
dataSetList.add(dsA); dataSetList.add(dsA);
dataSetList.add(dsB); dataSetList.add(dsB);
DataSet fullDataSet = DataSet.merge(dataSetList); DataSet fullDataSet = DataSet.merge(dataSetList);
@ -175,7 +175,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
// System.out.println(b); // System.out.println(b);
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
//broken at a threshold //broken at a threshold
public void testArgMax(Nd4jBackend backend) { public void testArgMax(Nd4jBackend backend) {
int max = 63; int max = 63;
@ -263,7 +264,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
// log.info("p50: {}; avg: {};", times.get(times.size() / 2), time); // log.info("p50: {}; avg: {};", times.get(times.size() / 2), time);
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void checkIllegalElementOps(Nd4jBackend backend) { public void checkIllegalElementOps(Nd4jBackend backend) {
assertThrows(Exception.class,() -> { assertThrows(Exception.class,() -> {
INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5); INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5);
@ -328,13 +330,13 @@ public class LoneTest extends BaseNd4jTestWithBackends {
reshaped.getDouble(i); reshaped.getDouble(i);
} }
for (int j=0;j<arr.slices();j++) { for (int j=0;j<arr.slices();j++) {
for (int k=0;k<arr.slice(j).length();k++) { for (int k = 0; k < arr.slice(j).length(); k++) {
// log.info("\nArr: slice " + j + " element " + k + " " + arr.slice(j).getDouble(k)); // log.info("\nArr: slice " + j + " element " + k + " " + arr.slice(j).getDouble(k));
arr.slice(j).getDouble(k); arr.slice(j).getDouble(k);
} }
} }
for (int j=0;j<reshaped.slices();j++) { for (int j = 0;j < reshaped.slices(); j++) {
for (int k=0;k<reshaped.slice(j).length();k++) { for (int k = 0;k < reshaped.slice(j).length(); k++) {
// log.info("\nReshaped: slice " + j + " element " + k + " " + reshaped.slice(j).getDouble(k)); // log.info("\nReshaped: slice " + j + " element " + k + " " + reshaped.slice(j).getDouble(k));
reshaped.slice(j).getDouble(k); reshaped.slice(j).getDouble(k);
} }

View File

@ -245,7 +245,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, false); INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, false);
assertEquals(sorted[1], sorted2); assertEquals(sorted[1], sorted2);
INDArray shouldIndex = Nd4j.create(new double[] {1, 1, 0, 0}, new long[] {2, 2}); INDArray shouldIndex = Nd4j.create(new double[] {1, 1, 0, 0}, new long[] {2, 2});
assertEquals(shouldIndex, sorted[0],getFailureMessage()); assertEquals(shouldIndex, sorted[0],getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest
@ -266,7 +266,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, true); INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, true);
assertEquals(sorted[1], sorted2); assertEquals(sorted[1], sorted2);
INDArray shouldIndex = Nd4j.create(new double[] {0, 0, 1, 1}, new long[] {2, 2}); INDArray shouldIndex = Nd4j.create(new double[] {0, 0, 1, 1}, new long[] {2, 2});
assertEquals(shouldIndex, sorted[0],getFailureMessage()); assertEquals(shouldIndex, sorted[0],getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest
@ -328,13 +328,13 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
public void testDivide(Nd4jBackend backend) { public void testDivide(Nd4jBackend backend) {
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2}); INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
INDArray div = two.div(two); INDArray div = two.div(two);
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage()); assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage(backend));
INDArray half = Nd4j.create(new float[] {0.5f, 0.5f, 0.5f, 0.5f}, new long[] {2, 2}); INDArray half = Nd4j.create(new float[] {0.5f, 0.5f, 0.5f, 0.5f}, new long[] {2, 2});
INDArray divi = Nd4j.create(new float[] {0.3f, 0.6f, 0.9f, 0.1f}, new long[] {2, 2}); INDArray divi = Nd4j.create(new float[] {0.3f, 0.6f, 0.9f, 0.1f}, new long[] {2, 2});
INDArray assertion = Nd4j.create(new float[] {1.6666666f, 0.8333333f, 0.5555556f, 5}, new long[] {2, 2}); INDArray assertion = Nd4j.create(new float[] {1.6666666f, 0.8333333f, 0.5555556f, 5}, new long[] {2, 2});
INDArray result = half.div(divi); INDArray result = half.div(divi);
assertEquals( assertion, result,getFailureMessage()); assertEquals( assertion, result,getFailureMessage(backend));
} }
@ -344,7 +344,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f}); INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
INDArray sigmoid = Transforms.sigmoid(n, false); INDArray sigmoid = Transforms.sigmoid(n, false);
assertEquals( assertion, sigmoid,getFailureMessage()); assertEquals( assertion, sigmoid,getFailureMessage(backend));
} }
@ -354,7 +354,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4}); INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
INDArray neg = Transforms.neg(n); INDArray neg = Transforms.neg(n);
assertEquals(assertion, neg,getFailureMessage()); assertEquals(assertion, neg,getFailureMessage(backend));
} }
@ -365,12 +365,12 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
double sim = Transforms.cosineSim(vec1, vec2); double sim = Transforms.cosineSim(vec1, vec2);
assertEquals(1, sim, 1e-1,getFailureMessage()); assertEquals(1, sim, 1e-1,getFailureMessage(backend));
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f}); INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f}); INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
sim = Transforms.cosineSim(vec3, vec4); sim = Transforms.cosineSim(vec3, vec4);
assertEquals(0.98, sim, 1e-1,getFailureMessage()); assertEquals(0.98, sim, 1e-1,getFailureMessage(backend));
} }
@ -621,7 +621,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray innerProduct = n.mmul(transposed); INDArray innerProduct = n.mmul(transposed);
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1); INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
assertEquals(scalar, innerProduct,getFailureMessage()); assertEquals(scalar, innerProduct,getFailureMessage(backend));
} }
@ -678,7 +678,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray five = Nd4j.ones(5); INDArray five = Nd4j.ones(5);
five.addi(five.dup()); five.addi(five.dup());
INDArray twos = Nd4j.valueArrayOf(5, 2); INDArray twos = Nd4j.valueArrayOf(5, 2);
assertEquals(twos, five,getFailureMessage()); assertEquals(twos, five,getFailureMessage(backend));
} }
@ -692,7 +692,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}}); INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
INDArray test = arr.mmul(arr.transpose()); INDArray test = arr.mmul(arr.transpose());
assertEquals(assertion, test,getFailureMessage()); assertEquals(assertion, test,getFailureMessage(backend));
} }
@ -704,7 +704,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
Nd4j.exec(new PrintVariable(newSlice)); Nd4j.exec(new PrintVariable(newSlice));
log.info("Slice: {}", newSlice); log.info("Slice: {}", newSlice);
n.putSlice(0, newSlice); n.putSlice(0, newSlice);
assertEquals( newSlice, n.slice(0),getFailureMessage()); assertEquals( newSlice, n.slice(0),getFailureMessage(backend));
} }
@ -713,7 +713,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
public void testRowVectorMultipleIndices(Nd4jBackend backend) { public void testRowVectorMultipleIndices(Nd4jBackend backend) {
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4); INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
linear.putScalar(new long[] {0, 1}, 1); linear.putScalar(new long[] {0, 1}, 1);
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage()); assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage(backend));
} }
@ -1059,7 +1059,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray nClone = n1.add(n2); INDArray nClone = n1.add(n2);
assertEquals(Nd4j.scalar(3), nClone); assertEquals(Nd4j.scalar(3), nClone);
INDArray n1PlusN2 = n1.add(n2); INDArray n1PlusN2 = n1.add(n2);
assertFalse(n1PlusN2.equals(n1),getFailureMessage()); assertFalse(n1PlusN2.equals(n1),getFailureMessage(backend));
INDArray n3 = Nd4j.scalar(3.0); INDArray n3 = Nd4j.scalar(3.0);
INDArray n4 = Nd4j.scalar(4.0); INDArray n4 = Nd4j.scalar(4.0);

View File

@ -156,7 +156,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType(); DataType initialType = Nd4j.dataType();
Level1 l1 = Nd4j.getBlasWrapper().level1(); Level1 l1 = Nd4j.getBlasWrapper().level1();
@TempDir Path testDir;
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
@ -255,7 +255,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSerialization(@TempDir Path testDir) throws Exception { public void testSerialization(Nd4jBackend backend) throws Exception {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray arr = Nd4j.rand(1, 20); INDArray arr = Nd4j.rand(1, 20);
@ -339,7 +339,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertArrayEquals(assertion,shapeTest); assertArrayEquals(assertion,shapeTest);
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled //temporary till libnd4j implements general broadcasting @Disabled //temporary till libnd4j implements general broadcasting
public void testAutoBroadcastAdd(Nd4jBackend backend) { public void testAutoBroadcastAdd(Nd4jBackend backend) {
INDArray left = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,1,2,1); INDArray left = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,1,2,1);
@ -370,9 +371,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
n.divi(Nd4j.scalar(1.0d)); n.divi(Nd4j.scalar(1.0d));
n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3}); n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
assertEquals(27, n.sumNumber().doubleValue(), 1e-1,getFailureMessage()); assertEquals(27, n.sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
INDArray a = n.slice(2); INDArray a = n.slice(2);
assertEquals( true, Arrays.equals(new long[] {3, 3}, a.shape()),getFailureMessage()); assertEquals( true, Arrays.equals(new long[] {3, 3}, a.shape()),getFailureMessage(backend));
} }
@ -478,12 +479,13 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}}); INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
INDArray test = arr.mmul(arr.transpose()); INDArray test = arr.mmul(arr.transpose());
assertEquals(assertion, test,getFailureMessage()); assertEquals(assertion, test,getFailureMessage(backend));
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled @Disabled
public void testMmulOp() throws Exception { public void testMmulOp(Nd4jBackend backend) throws Exception {
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}}); INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
INDArray z = Nd4j.create(2, 2); INDArray z = Nd4j.create(2, 2);
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}}); INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
@ -494,7 +496,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
DynamicCustomOp op = new Mmul(arr, arr, z, mMulTranspose); DynamicCustomOp op = new Mmul(arr, arr, z, mMulTranspose);
Nd4j.getExecutioner().execAndReturn(op); Nd4j.getExecutioner().execAndReturn(op);
assertEquals(assertion, z,getFailureMessage()); assertEquals(assertion, z,getFailureMessage(backend));
} }
@ -505,7 +507,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray row1 = oneThroughFour.getRow(1).dup(); INDArray row1 = oneThroughFour.getRow(1).dup();
oneThroughFour.subiRowVector(row1); oneThroughFour.subiRowVector(row1);
INDArray result = Nd4j.create(new double[] {-2, -2, 0, 0}, new long[] {2, 2}); INDArray result = Nd4j.create(new double[] {-2, -2, 0, 0}, new long[] {2, 2});
assertEquals(result, oneThroughFour,getFailureMessage()); assertEquals(result, oneThroughFour,getFailureMessage(backend));
} }
@ -1093,7 +1095,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertTrue(expAllOnes.all()); assertTrue(expAllOnes.all());
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled @Disabled
public void testSumAlongDim1sEdgeCases(Nd4jBackend backend) { public void testSumAlongDim1sEdgeCases(Nd4jBackend backend) {
val shapes = new long[][] { val shapes = new long[][] {
@ -1227,7 +1230,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray row1 = oneThroughFour.getRow(1); INDArray row1 = oneThroughFour.getRow(1);
row1.addi(1); row1.addi(1);
INDArray result = Nd4j.create(new double[] {1, 2, 4, 5}, new long[] {2, 2}); INDArray result = Nd4j.create(new double[] {1, 2, 4, 5}, new long[] {2, 2});
assertEquals(result, oneThroughFour,getFailureMessage()); assertEquals(result, oneThroughFour,getFailureMessage(backend));
} }
@ -1241,8 +1244,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray linear = test.reshape(-1); INDArray linear = test.reshape(-1);
linear.putScalar(2, 6); linear.putScalar(2, 6);
linear.putScalar(3, 7); linear.putScalar(3, 7);
assertEquals(6, linear.getFloat(2), 1e-1,getFailureMessage()); assertEquals(6, linear.getFloat(2), 1e-1,getFailureMessage(backend));
assertEquals(7, linear.getFloat(3), 1e-1,getFailureMessage()); assertEquals(7, linear.getFloat(3), 1e-1,getFailureMessage(backend));
} }
@ -1609,7 +1612,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4}); INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
INDArray neg = Transforms.neg(n); INDArray neg = Transforms.neg(n);
assertEquals(assertion, neg,getFailureMessage()); assertEquals(assertion, neg,getFailureMessage(backend));
} }
@ -1622,13 +1625,13 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
double assertion = 5.47722557505; double assertion = 5.47722557505;
double norm3 = n.norm2Number().doubleValue(); double norm3 = n.norm2Number().doubleValue();
assertEquals(assertion, norm3, 1e-1,getFailureMessage()); assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray row1 = row.getRow(1); INDArray row1 = row.getRow(1);
double norm2 = row1.norm2Number().doubleValue(); double norm2 = row1.norm2Number().doubleValue();
double assertion2 = 5.0f; double assertion2 = 5.0f;
assertEquals(assertion2, norm2, 1e-1,getFailureMessage()); assertEquals(assertion2, norm2, 1e-1,getFailureMessage(backend));
Nd4j.setDataType(initialType); Nd4j.setDataType(initialType);
} }
@ -1640,14 +1643,14 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
float assertion = 5.47722557505f; float assertion = 5.47722557505f;
float norm3 = n.norm2Number().floatValue(); float norm3 = n.norm2Number().floatValue();
assertEquals(assertion, norm3, 1e-1,getFailureMessage()); assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
INDArray row = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray row = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray row1 = row.getRow(1); INDArray row1 = row.getRow(1);
float norm2 = row1.norm2Number().floatValue(); float norm2 = row1.norm2Number().floatValue();
float assertion2 = 5.0f; float assertion2 = 5.0f;
assertEquals(assertion2, norm2, 1e-1,getFailureMessage()); assertEquals(assertion2, norm2, 1e-1,getFailureMessage(backend));
} }
@ -1659,7 +1662,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
double sim = Transforms.cosineSim(vec1, vec2); double sim = Transforms.cosineSim(vec1, vec2);
assertEquals(1, sim, 1e-1,getFailureMessage()); assertEquals(1, sim, 1e-1,getFailureMessage(backend));
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f}); INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f}); INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
@ -1675,14 +1678,14 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
double assertion = 2; double assertion = 2;
INDArray answer = Nd4j.create(new double[] {2, 4, 6, 8}); INDArray answer = Nd4j.create(new double[] {2, 4, 6, 8});
INDArray scal = Nd4j.getBlasWrapper().scal(assertion, answer); INDArray scal = Nd4j.getBlasWrapper().scal(assertion, answer);
assertEquals(answer, scal,getFailureMessage()); assertEquals(answer, scal,getFailureMessage(backend));
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2}); INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray row1 = row.getRow(1); INDArray row1 = row.getRow(1);
double assertion2 = 5.0; double assertion2 = 5.0;
INDArray answer2 = Nd4j.create(new double[] {15, 20}); INDArray answer2 = Nd4j.create(new double[] {15, 20});
INDArray scal2 = Nd4j.getBlasWrapper().scal(assertion2, row1); INDArray scal2 = Nd4j.getBlasWrapper().scal(assertion2, row1);
assertEquals(answer2, scal2,getFailureMessage()); assertEquals(answer2, scal2,getFailureMessage(backend));
} }
@ -2076,17 +2079,17 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray innerProduct = n.mmul(transposed); INDArray innerProduct = n.mmul(transposed);
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1); INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
assertEquals(scalar, innerProduct,getFailureMessage()); assertEquals(scalar, innerProduct,getFailureMessage(backend));
INDArray outerProduct = transposed.mmul(n); INDArray outerProduct = transposed.mmul(n);
assertEquals(true, Shape.shapeEquals(new long[] {10, 10}, outerProduct.shape()),getFailureMessage()); assertEquals(true, Shape.shapeEquals(new long[] {10, 10}, outerProduct.shape()),getFailureMessage(backend));
INDArray three = Nd4j.create(new double[] {3, 4}); INDArray three = Nd4j.create(new double[] {3, 4});
INDArray test = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2}); INDArray test = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2});
INDArray sliceRow = test.slice(0).getRow(1); INDArray sliceRow = test.slice(0).getRow(1);
assertEquals(three, sliceRow,getFailureMessage()); assertEquals(three, sliceRow,getFailureMessage(backend));
INDArray twoSix = Nd4j.create(new double[] {2, 6}, new long[] {2, 1}); INDArray twoSix = Nd4j.create(new double[] {2, 6}, new long[] {2, 1});
INDArray threeTwoSix = three.mmul(twoSix); INDArray threeTwoSix = three.mmul(twoSix);
@ -2114,7 +2117,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray k1 = n1.transpose(); INDArray k1 = n1.transpose();
INDArray testVectorVector = k1.mmul(n1); INDArray testVectorVector = k1.mmul(n1);
assertEquals(vectorVector, testVectorVector,getFailureMessage()); assertEquals(vectorVector, testVectorVector,getFailureMessage(backend));
} }
@ -2204,7 +2207,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(linear.getDouble(0, 1), 1, 1e-1); assertEquals(linear.getDouble(0, 1), 1, 1e-1);
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSize(Nd4jBackend backend) { public void testSize(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
INDArray arr = Nd4j.create(4, 5); INDArray arr = Nd4j.create(4, 5);
@ -2357,7 +2361,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled @Disabled
public void testTensorDot(Nd4jBackend backend) { public void testTensorDot(Nd4jBackend backend) {
INDArray oneThroughSixty = Nd4j.arange(60).reshape(3, 4, 5).castTo(DataType.DOUBLE); INDArray oneThroughSixty = Nd4j.arange(60).reshape(3, 4, 5).castTo(DataType.DOUBLE);
@ -3051,10 +3056,10 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
public void testMeans(Nd4jBackend backend) { public void testMeans(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray mean1 = a.mean(1); INDArray mean1 = a.mean(1);
assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage()); assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage(backend));
assertEquals(Nd4j.create(new double[] {2, 3}), a.mean(0),getFailureMessage()); assertEquals(Nd4j.create(new double[] {2, 3}), a.mean(0),getFailureMessage(backend));
assertEquals(2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1,getFailureMessage()); assertEquals(2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1,getFailureMessage(backend));
assertEquals(2.5, a.meanNumber().doubleValue(), 1e-1,getFailureMessage()); assertEquals(2.5, a.meanNumber().doubleValue(), 1e-1,getFailureMessage(backend));
} }
@ -3063,9 +3068,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSums(Nd4jBackend backend) { public void testSums(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage()); assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage(backend));
assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage()); assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage(backend));
assertEquals(10, a.sumNumber().doubleValue(), 1e-1,getFailureMessage()); assertEquals(10, a.sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
} }
@ -3438,7 +3443,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled @Disabled
public void largeInstantiation(Nd4jBackend backend) { public void largeInstantiation(Nd4jBackend backend) {
Nd4j.ones((1024 * 1024 * 511) + (1024 * 1024 - 1)); // Still works; this can even be called as often as I want, allowing me even to spill over on disk Nd4j.ones((1024 * 1024 * 511) + (1024 * 1024 - 1)); // Still works; this can even be called as often as I want, allowing me even to spill over on disk
@ -3487,7 +3493,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(cSum, fSum); //Expect: 4,6. Getting [4, 4] for f order assertEquals(cSum, fSum); //Expect: 4,6. Getting [4, 4] for f order
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled //not relevant anymore @Disabled //not relevant anymore
public void testAssignMixedC(Nd4jBackend backend) { public void testAssignMixedC(Nd4jBackend backend) {
int[] shape1 = {3, 2, 2, 2, 2, 2}; int[] shape1 = {3, 2, 2, 2, 2, 2};
@ -3787,7 +3794,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, result); assertEquals(assertion, result);
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation1(Nd4jBackend backend) { public void testPullRowsValidation1(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(10, 10), 2, new int[] {0, 1, 2}); Nd4j.pullRows(Nd4j.create(10, 10), 2, new int[] {0, 1, 2});
@ -3795,7 +3803,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}); });
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation2(Nd4jBackend backend) { public void testPullRowsValidation2(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, -1, 2}); Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, -1, 2});
@ -3803,7 +3812,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}); });
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation3(Nd4jBackend backend) { public void testPullRowsValidation3(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, 1, 10}); Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, 1, 10});
@ -3811,7 +3821,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}); });
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation4(Nd4jBackend backend) { public void testPullRowsValidation4(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2, 3}); Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2, 3});
@ -3819,7 +3830,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}); });
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation5(Nd4jBackend backend) { public void testPullRowsValidation5(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2}, 'e'); Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2}, 'e');
@ -4975,7 +4987,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
} }
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTadReduce3_5(Nd4jBackend backend) { public void testTadReduce3_5(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> { assertThrows(ND4JIllegalStateException.class,() -> {
INDArray initial = Nd4j.create(5, 10); INDArray initial = Nd4j.create(5, 10);
@ -6004,7 +6017,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled @Disabled
public void testLogExpSum1(Nd4jBackend backend) { public void testLogExpSum1(Nd4jBackend backend) {
INDArray matrix = Nd4j.create(3, 3); INDArray matrix = Nd4j.create(3, 3);
@ -6019,7 +6033,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
} }
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled @Disabled
public void testLogExpSum2(Nd4jBackend backend) { public void testLogExpSum2(Nd4jBackend backend) {
INDArray row = Nd4j.create(new double[]{1, 2, 3}); INDArray row = Nd4j.create(new double[]{1, 2, 3});
@ -6246,7 +6261,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
} }
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReshapeFailure(Nd4jBackend backend) { public void testReshapeFailure(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> { assertThrows(ND4JIllegalStateException.class,() -> {
val a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2); val a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2);
@ -6345,7 +6361,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertArrayEquals(new long[]{3, 2}, newShape.shape()); assertArrayEquals(new long[]{3, 2}, newShape.shape());
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTranspose1(Nd4jBackend backend) { public void testTranspose1(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6}); val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
@ -6360,7 +6377,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTranspose2(Nd4jBackend backend) { public void testTranspose2(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
val scalar = Nd4j.scalar(2.f); val scalar = Nd4j.scalar(2.f);
@ -6375,7 +6393,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
//@Disabled //@Disabled
public void testMatmul_128by256(Nd4jBackend backend) { public void testMatmul_128by256(Nd4jBackend backend) {
val mA = Nd4j.create(128, 156).assign(1.0f); val mA = Nd4j.create(128, 156).assign(1.0f);
@ -6647,7 +6666,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(exp1, out1); assertEquals(exp1, out1);
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBadReduce3Call(Nd4jBackend backend) { public void testBadReduce3Call(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> { assertThrows(ND4JIllegalStateException.class,() -> {
val x = Nd4j.create(400,20); val x = Nd4j.create(400,20);
@ -7392,8 +7412,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(ez, z); assertEquals(ez, z);
} }
@Test() @ParameterizedTest
public void testBroadcastInvalid(){ @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBroadcastInvalid() {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
INDArray arr1 = Nd4j.ones(3,4,1); INDArray arr1 = Nd4j.ones(3,4,1);
@ -7656,7 +7677,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(exp, array); assertEquals(exp, array);
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScatterUpdateShortcut_f1(Nd4jBackend backend) { public void testScatterUpdateShortcut_f1(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
val array = Nd4j.create(DataType.FLOAT, 5, 2); val array = Nd4j.create(DataType.FLOAT, 5, 2);
@ -8041,7 +8063,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(exp, out); //Failing here assertEquals(exp, out); //Failing here
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsFailure(Nd4jBackend backend) { public void testPullRowsFailure(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
val idxs = new int[]{0,2,3,4}; val idxs = new int[]{0,2,3,4};
@ -8144,7 +8167,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(exp1, out1); //This is OK assertEquals(exp1, out1); //This is OK
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPutRowValidation(Nd4jBackend backend) { public void testPutRowValidation(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
val matrix = Nd4j.create(5, 10); val matrix = Nd4j.create(5, 10);
@ -8155,7 +8179,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPutColumnValidation(Nd4jBackend backend) { public void testPutColumnValidation(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
val matrix = Nd4j.create(5, 10); val matrix = Nd4j.create(5, 10);
@ -8236,7 +8261,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScalarEq(){ public void testScalarEq(Nd4jBackend backend){
INDArray scalarRank2 = Nd4j.scalar(10.0).reshape(1,1); INDArray scalarRank2 = Nd4j.scalar(10.0).reshape(1,1);
INDArray scalarRank1 = Nd4j.scalar(10.0).reshape(1); INDArray scalarRank1 = Nd4j.scalar(10.0).reshape(1);
INDArray scalarRank0 = Nd4j.scalar(10.0); INDArray scalarRank0 = Nd4j.scalar(10.0);
@ -8273,7 +8298,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testType1(@TempDir Path testDir) throws IOException { @Disabled
public void testType1(Nd4jBackend backend) throws IOException {
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
INDArray in1 = Nd4j.rand(DataType.DOUBLE, new int[]{100, 100}); INDArray in1 = Nd4j.rand(DataType.DOUBLE, new int[]{100, 100});
File dir = testDir.toFile(); File dir = testDir.toFile();
@ -8295,7 +8321,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOnes(){ public void testOnes(Nd4jBackend backend){
INDArray arr = Nd4j.ones(); INDArray arr = Nd4j.ones();
INDArray arr2 = Nd4j.ones(DataType.LONG); INDArray arr2 = Nd4j.ones(DataType.LONG);
assertEquals(0, arr.rank()); assertEquals(0, arr.rank());
@ -8306,7 +8332,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testZeros(){ public void testZeros(Nd4jBackend backend){
INDArray arr = Nd4j.zeros(); INDArray arr = Nd4j.zeros();
INDArray arr2 = Nd4j.zeros(DataType.LONG); INDArray arr2 = Nd4j.zeros(DataType.LONG);
assertEquals(0, arr.rank()); assertEquals(0, arr.rank());
@ -8317,7 +8343,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testType2(@TempDir Path testDir) throws IOException { @Disabled
public void testType2(Nd4jBackend backend) throws IOException {
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
INDArray in1 = Nd4j.ones(DataType.UINT16); INDArray in1 = Nd4j.ones(DataType.UINT16);
File dir = testDir.toFile(); File dir = testDir.toFile();

View File

@ -23,6 +23,7 @@ package org.nd4j.linalg;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
@ -58,11 +59,12 @@ public class ToStringTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testToStringScalars(){ @Disabled
public void testToStringScalars(Nd4jBackend backend){
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32}; DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"}; String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};
for(int dt=0; dt<5; dt++ ) { for(int dt = 0; dt < 5; dt++) {
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
long[] shape = ArrayUtil.nTimes(i, 1L); long[] shape = ArrayUtil.nTimes(i, 1L);
INDArray scalar = Nd4j.scalar(1.0f).castTo(dataTypes[dt]).reshape(shape); INDArray scalar = Nd4j.scalar(1.0f).castTo(dataTypes[dt]).reshape(shape);

View File

@ -64,7 +64,6 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
} }
@Test
@Disabled @Disabled
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@ -79,7 +78,6 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
} }
@Test
@Disabled @Disabled
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@ -100,7 +98,8 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
} }
@Test @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCreateNpy3(Nd4jBackend backend) throws Exception { public void testCreateNpy3(Nd4jBackend backend) throws Exception {
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile()); INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
assertEquals(8, arrCreate.length()); assertEquals(8, arrCreate.length());
@ -111,8 +110,9 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
assertEquals(arrCreate.data().address(), pointer.address()); assertEquals(arrCreate.data().address(), pointer.address());
} }
@Test
@Disabled // this is endless test @Disabled // this is endless test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEndlessAllocation(Nd4jBackend backend) { public void testEndlessAllocation(Nd4jBackend backend) {
Nd4j.getEnvironment().setMaxSpecialMemory(1); Nd4j.getEnvironment().setMaxSpecialMemory(1);
while (true) { while (true) {
@ -121,9 +121,10 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
} }
} }
@Test
@Disabled("This test is designed to run in isolation. With parallel gc it makes no real sense since allocated amount changes at any time") @Disabled("This test is designed to run in isolation. With parallel gc it makes no real sense since allocated amount changes at any time")
public void testAllocationLimits() throws Exception { @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAllocationLimits(Nd4jBackend backend) throws Exception {
Nd4j.create(1); Nd4j.create(1);
val origDeviceLimit = Nd4j.getEnvironment().getDeviceLimit(0); val origDeviceLimit = Nd4j.getEnvironment().getDeviceLimit(0);

View File

@ -20,7 +20,6 @@
package org.nd4j.linalg.api; package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.BaseNd4jTestWithBackends;

View File

@ -59,7 +59,7 @@ public class Level1Test extends BaseNd4jTestWithBackends {
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray row = matrix.getRow(1); INDArray row = matrix.getRow(1);
Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row); Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row);
assertEquals(Nd4j.create(new double[] {4, 8}), row,getFailureMessage()); assertEquals(Nd4j.create(new double[] {4, 8}), row,getFailureMessage(backend));
} }

View File

@ -70,8 +70,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
/** /**
* Testing level1 blas * Testing level1 blas
*/ */
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBlasValidation1(Nd4jBackend backend) { public void testBlasValidation1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> { assertThrows(ND4JIllegalStateException.class,() -> {
@ -89,8 +88,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
/** /**
* Testing level2 blas * Testing level2 blas
*/ */
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBlasValidation2(Nd4jBackend backend) { public void testBlasValidation2(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> { assertThrows(RuntimeException.class,() -> {
@ -109,8 +107,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
/** /**
* Testing level3 blas * Testing level3 blas
*/ */
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBlasValidation3(Nd4jBackend backend) { public void testBlasValidation3(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {

View File

@ -88,7 +88,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
float[] d1 = new float[] {1, 2, 3, 4}; float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1); DataBuffer d = Nd4j.createBuffer(d1);
float[] d2 = d.asFloat(); float[] d2 = d.asFloat();
assertArrayEquals( d1, d2, 1e-1f,getFailureMessage()); assertArrayEquals( d1, d2, 1e-1f,getFailureMessage(backend));
} }
@ -146,7 +146,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
d.put(0, 0.0); d.put(0, 0.0);
float[] result = new float[] {0, 2, 3, 4}; float[] result = new float[] {0, 2, 3, 4};
d1 = d.asFloat(); d1 = d.asFloat();
assertArrayEquals(d1, result, 1e-1f,getFailureMessage()); assertArrayEquals(d1, result, 1e-1f,getFailureMessage(backend));
} }
@ -156,12 +156,12 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(0, 3); float[] get = buffer.getFloatsAt(0, 3);
float[] data = new float[] {1, 2, 3}; float[] data = new float[] {1, 2, 3};
assertArrayEquals(get, data, 1e-1f,getFailureMessage()); assertArrayEquals(get, data, 1e-1f,getFailureMessage(backend));
float[] get2 = buffer.asFloat(); float[] get2 = buffer.asFloat();
float[] allData = buffer.getFloatsAt(0, (int) buffer.length()); float[] allData = buffer.getFloatsAt(0, (int) buffer.length());
assertArrayEquals(get2, allData, 1e-1f,getFailureMessage()); assertArrayEquals(get2, allData, 1e-1f,getFailureMessage(backend));
} }
@ -173,13 +173,13 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data(); DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(1, 3); float[] get = buffer.getFloatsAt(1, 3);
float[] data = new float[] {2, 3, 4}; float[] data = new float[] {2, 3, 4};
assertArrayEquals(get, data, 1e-1f,getFailureMessage()); assertArrayEquals(get, data, 1e-1f,getFailureMessage(backend));
float[] allButLast = new float[] {2, 3, 4, 5}; float[] allButLast = new float[] {2, 3, 4, 5};
float[] allData = buffer.getFloatsAt(1, (int) buffer.length()); float[] allData = buffer.getFloatsAt(1, (int) buffer.length());
assertArrayEquals(allButLast, allData, 1e-1f,getFailureMessage()); assertArrayEquals(allButLast, allData, 1e-1f,getFailureMessage(backend));
} }
@ -190,7 +190,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
public void testAsBytes(Nd4jBackend backend) { public void testAsBytes(Nd4jBackend backend) {
INDArray arr = Nd4j.create(5); INDArray arr = Nd4j.create(5);
byte[] d = arr.data().asBytes(); byte[] d = arr.data().asBytes();
assertEquals(4 * 5, d.length,getFailureMessage()); assertEquals(4 * 5, d.length,getFailureMessage(backend));
INDArray rand = Nd4j.rand(3, 3); INDArray rand = Nd4j.rand(3, 3);
rand.data().asBytes(); rand.data().asBytes();

View File

@ -20,26 +20,18 @@
package org.nd4j.linalg.api.indexing; package org.nd4j.linalg.api.indexing;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.*;
import org.nd4j.linalg.indexing.IntervalIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndexAll;
import org.nd4j.linalg.indexing.NewAxis;
import org.nd4j.linalg.indexing.PointIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.common.util.ArrayUtil;
import java.util.Arrays; import java.util.Arrays;
import java.util.Random; import java.util.Random;
@ -56,22 +48,22 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNegativeBounds() { public void testNegativeBounds(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5); INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1)); INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
INDArray get = arr.get(NDArrayIndex.all(),interval); INDArray get = arr.get(NDArrayIndex.all(),interval);
INDArray assertion = Nd4j.create(new double[][]{ INDArray assertion = Nd4j.create(new double[][]{
{1,2,3}, {1,2,3},
{6,7,8} {6,7,8}
}); });
assertEquals(assertion,get); assertEquals(assertion,get);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNewAxis() { public void testNewAxis(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2); INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2);
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all()); INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all());
long[] shapeAssertion = {3, 2, 1, 1, 2}; long[] shapeAssertion = {3, 2, 1, 1, 2};
@ -79,9 +71,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void broadcastBug() { public void broadcastBug(Nd4jBackend backend) {
INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2}); INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2});
final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0)); final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0));
@ -91,9 +83,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIntervalsIn3D() { public void testIntervalsIn3D(Nd4jBackend backend) {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE); INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2); INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
INDArray rest = arr.get(interval(1, 2), interval(0, 2), interval(0, 2)); INDArray rest = arr.get(interval(1, 2), interval(0, 2), interval(0, 2));
@ -101,9 +93,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSmallInterval() { public void testSmallInterval(Nd4jBackend backend) {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE); INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2); INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
INDArray rest = arr.get(interval(1, 2), all(), all()); INDArray rest = arr.get(interval(1, 2), all(), all());
@ -111,9 +103,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAllWithNewAxisAndInterval() { public void testAllWithNewAxisAndInterval(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3); INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3);
@ -121,9 +113,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion2, get2); assertEquals(assertion2, get2);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAllWithNewAxisInMiddle() { public void testAllWithNewAxisInMiddle(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3); INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3);
@ -131,20 +123,20 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion2, get2); assertEquals(assertion2, get2);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAllWithNewAxis() { public void testAllWithNewAxis(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray get = arr.get(newAxis(), all(), point(1)); INDArray get = arr.get(newAxis(), all(), point(1));
INDArray assertion = Nd4j.create(new double[][] {{4, 5, 6}, {10, 11, 12}, {16, 17, 18}, {22, 23, 24}}) INDArray assertion = Nd4j.create(new double[][] {{4, 5, 6}, {10, 11, 12}, {16, 17, 18}, {22, 23, 24}})
.reshape(1, 4, 3); .reshape(1, 4, 3);
assertEquals(assertion, get); assertEquals(assertion, get);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIndexingWithMmul() { public void testIndexingWithMmul(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3); INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1); INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
// System.out.println(b); // System.out.println(b);
@ -154,9 +146,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, c); assertEquals(assertion, c);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPointPointInterval() { public void testPointPointInterval(Nd4jBackend backend) {
INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3); INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3);
INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3)); INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3));
INDArray assertion = Nd4j.create(new double[][] {{5, 6}, {8, 9}}); INDArray assertion = Nd4j.create(new double[][] {{5, 6}, {8, 9}});
@ -164,9 +156,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, get); assertEquals(assertion, get);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIntervalLowerBound() { public void testIntervalLowerBound(Nd4jBackend backend) {
INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3); INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2)); INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
INDArray assertion = Nd4j.create(new double[][] {{7, 9}, {13, 15}}); INDArray assertion = Nd4j.create(new double[][] {{7, 9}, {13, 15}});
@ -176,9 +168,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetPointRowVector() { public void testGetPointRowVector(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1); INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1);
INDArray arr2 = arr.get(point(0), interval(0, 100)); INDArray arr2 = arr.get(point(0), interval(0, 100));
@ -187,9 +179,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(Nd4j.linspace(1, 100, 100, DataType.DOUBLE), arr2); assertEquals(Nd4j.linspace(1, 100, 100, DataType.DOUBLE), arr2);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSpecifiedIndexVector() { public void testSpecifiedIndexVector(Nd4jBackend backend) {
INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4); INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2); INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2);
INDArray get = rootMatrix.get(all(), new SpecifiedIndex(0, 2)); INDArray get = rootMatrix.get(all(), new SpecifiedIndex(0, 2));
@ -205,9 +197,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPutRowIndexing() { public void testPutRowIndexing(Nd4jBackend backend) {
INDArray arr = Nd4j.ones(1, 10); INDArray arr = Nd4j.ones(1, 10);
INDArray row = Nd4j.create(1, 10); INDArray row = Nd4j.create(1, 10);
@ -216,9 +208,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(arr, row); assertEquals(arr, row);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testVectorIndexing2() { public void testVectorIndexing2(Nd4jBackend backend) {
INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true)); INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true));
INDArray assertion = Nd4j.create(new double[] {2, 4}); INDArray assertion = Nd4j.create(new double[] {2, 4});
assertEquals(assertion, wholeVector); assertEquals(assertion, wholeVector);
@ -232,9 +224,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOffsetsC() { public void testOffsetsC(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
assertEquals(3, NDArrayIndex.offset(arr, 1, 1)); assertEquals(3, NDArrayIndex.offset(arr, 1, 1));
assertEquals(3, NDArrayIndex.offset(arr, point(1), point(1))); assertEquals(3, NDArrayIndex.offset(arr, point(1), point(1)));
@ -249,9 +241,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIndexFor() { public void testIndexFor(Nd4jBackend backend) {
long[] shape = {1, 2}; long[] shape = {1, 2};
INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape); INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape);
for (int i = 0; i < indexes.length; i++) { for (int i = 0; i < indexes.length; i++) {
@ -259,9 +251,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
} }
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetScalar() { public void testGetScalar(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
INDArray d = arr.get(point(1)); INDArray d = arr.get(point(1));
assertTrue(d.isScalar()); assertTrue(d.isScalar());
@ -269,26 +261,26 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testVectorIndexing() { public void testVectorIndexing(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1); INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1);
INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5}); INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
INDArray viewTest = arr.get(point(0), interval(1, 5)); INDArray viewTest = arr.get(point(0), interval(1, 5));
assertEquals(assertion, viewTest); assertEquals(assertion, viewTest);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNegativeIndices() { public void testNegativeIndices(Nd4jBackend backend) {
INDArray test = Nd4j.create(10, 10, 10); INDArray test = Nd4j.create(10, 10, 10);
test.putScalar(new int[] {0, 0, -1}, 1.0); test.putScalar(new int[] {0, 0, -1}, 1.0);
assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber()); assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber());
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetIndices2d() { public void testGetIndices2d(Nd4jBackend backend) {
INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2); INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2);
INDArray firstRow = twoByTwo.getRow(0); INDArray firstRow = twoByTwo.getRow(0);
INDArray secondRow = twoByTwo.getRow(1); INDArray secondRow = twoByTwo.getRow(1);
@ -305,9 +297,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(Nd4j.create(new double[] {4}, new int[]{1,1}), individualElement); assertEquals(Nd4j.create(new double[] {4}, new int[]{1,1}), individualElement);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetRow() { public void testGetRow(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5); INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5);
int[] toGet = {0, 1}; int[] toGet = {0, 1};
@ -323,9 +315,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetRowEdgeCase() { public void testGetRowEdgeCase(Nd4jBackend backend) {
INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1); INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
INDArray get = rowVec.getRow(0); //Returning shape [1,1] INDArray get = rowVec.getRow(0); //Returning shape [1,1]
@ -333,9 +325,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(rowVec, get); assertEquals(rowVec, get);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetColumnEdgeCase() { public void testGetColumnEdgeCase(Nd4jBackend backend) {
INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose(); INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose();
INDArray get = colVec.getColumn(0); //Returning shape [1,1] INDArray get = colVec.getColumn(0); //Returning shape [1,1]
@ -343,9 +335,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(colVec, get); assertEquals(colVec, get);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConcatColumns() { public void testConcatColumns(Nd4jBackend backend) {
INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE); INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE);
INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE); INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE);
INDArray concat = Nd4j.concat(1, input1, input2); INDArray concat = Nd4j.concat(1, input1, input2);
@ -353,18 +345,18 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, concat); assertEquals(assertion, concat);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetIndicesVector() { public void testGetIndicesVector(Nd4jBackend backend) {
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1); INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
INDArray test = Nd4j.create(new double[] {2, 3}); INDArray test = Nd4j.create(new double[] {2, 3});
INDArray result = line.get(point(0), interval(1, 3)); INDArray result = line.get(point(0), interval(1, 3));
assertEquals(test, result); assertEquals(test, result);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testArangeMul() { public void testArangeMul(Nd4jBackend backend) {
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE); INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
INDArrayIndex index = interval(0, 2); INDArrayIndex index = interval(0, 2);
INDArray get = arange.get(index, index); INDArray get = arange.get(index, index);
@ -374,7 +366,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, mul); assertEquals(assertion, mul);
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIndexingThorough(){ public void testIndexingThorough(){
long[] fullShape = {3,4,5,6,7}; long[] fullShape = {3,4,5,6,7};
@ -575,7 +567,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
return d; return d;
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void debugging(){ public void debugging(){
long[] inShape = {3,4}; long[] inShape = {3,4};

View File

@ -46,12 +46,13 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j @Slf4j
public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends { public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void compareAfterWrite(Nd4jBackend backend) throws Exception {
int [] ranksToCheck = new int[] {0,1,2,3,4}; int [] ranksToCheck = new int[] {0,1,2,3,4};
for (int i = 0; i < ranksToCheck.length; i++) { for (int i = 0; i < ranksToCheck.length; i++) {
// log.info("Checking read write arrays with rank " + ranksToCheck[i]); // log.info("Checking read write arrays with rank " + ranksToCheck[i]);
@ -82,7 +83,7 @@ public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNd4jReadWriteText(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testNd4jReadWriteText(Nd4jBackend backend) throws Exception {
File dir = testDir.toFile(); File dir = testDir.toFile();
int count = 0; int count = 0;

View File

@ -38,11 +38,11 @@ import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays;
@Slf4j @Slf4j
public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends { public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void compareAfterWrite(Nd4jBackend backend) throws Exception {
int[] ranksToCheck = new int[]{0, 1, 2, 3, 4}; int[] ranksToCheck = new int[]{0, 1, 2, 3, 4};
for (int i = 0; i < ranksToCheck.length; i++) { for (int i = 0; i < ranksToCheck.length; i++) {
log.info("Checking read write arrays with rank " + ranksToCheck[i]); log.info("Checking read write arrays with rank " + ranksToCheck[i]);

View File

@ -22,6 +22,7 @@ package org.nd4j.linalg.broadcast;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
@ -135,7 +136,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
assertEquals(e, z); assertEquals(e, z);
} }
@Test()
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_1(Nd4jBackend backend) { public void basicBroadcastFailureTest_1(Nd4jBackend backend) {
@ -146,7 +146,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
}); });
} }
@Test()
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_2(Nd4jBackend backend) { public void basicBroadcastFailureTest_2(Nd4jBackend backend) {
@ -158,7 +157,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
} }
@Test()
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_3(Nd4jBackend backend) { public void basicBroadcastFailureTest_3(Nd4jBackend backend) {
@ -170,16 +168,15 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
} }
@Test()
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void basicBroadcastFailureTest_4(Nd4jBackend backend) { public void basicBroadcastFailureTest_4(Nd4jBackend backend) {
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
val z = x.addi(y); val z = x.addi(y);
} }
@Test()
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_5(Nd4jBackend backend) { public void basicBroadcastFailureTest_5(Nd4jBackend backend) {
@ -191,7 +188,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
} }
@Test()
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_6(Nd4jBackend backend) { public void basicBroadcastFailureTest_6(Nd4jBackend backend) {
@ -249,9 +245,9 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
assertEquals(y, z); assertEquals(y, z);
} }
@Test()
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void emptyBroadcastTest_2(Nd4jBackend backend) { public void emptyBroadcastTest_2(Nd4jBackend backend) {
val x = Nd4j.create(DataType.FLOAT, 1, 2); val x = Nd4j.create(DataType.FLOAT, 1, 2);
val y = Nd4j.create(DataType.FLOAT, 0, 2); val y = Nd4j.create(DataType.FLOAT, 0, 2);

View File

@ -37,7 +37,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class CompressionMagicTests extends BaseNd4jTestWithBackends { public class CompressionMagicTests extends BaseNd4jTestWithBackends {
@BeforeEach @BeforeEach
public void setUp(Nd4jBackend backend) { public void setUp() {
} }

View File

@ -48,6 +48,7 @@ import java.util.Set;
public class DeconvTests extends BaseNd4jTestWithBackends { public class DeconvTests extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@Override @Override
public char ordering() { public char ordering() {
@ -56,7 +57,7 @@ public class DeconvTests extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void compareKeras(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void compareKeras(Nd4jBackend backend) throws Exception {
File newFolder = testDir.toFile(); File newFolder = testDir.toFile();
new ClassPathResource("keras/deconv/").copyDirectory(newFolder); new ClassPathResource("keras/deconv/").copyDirectory(newFolder);

View File

@ -99,7 +99,8 @@ public class SpecialTests extends BaseNd4jTestWithBackends {
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScalarShuffle1(Nd4jBackend backend) { public void testScalarShuffle1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> { assertThrows(ND4JIllegalStateException.class,() -> {
List<DataSet> listData = new ArrayList<>(); List<DataSet> listData = new ArrayList<>();

View File

@ -195,7 +195,8 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
assertEquals(exp, arrayX); assertEquals(exp, arrayX);
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInplaceOp1(Nd4jBackend backend) { public void testInplaceOp1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> { assertThrows(ND4JIllegalStateException.class,() -> {
val arrayX = Nd4j.create(10, 10); val arrayX = Nd4j.create(10, 10);

View File

@ -41,10 +41,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends { public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBalance(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testBalance(Nd4jBackend backend) throws Exception {
DataSetIterator iterator = new IrisDataSetIterator(10, 150); DataSetIterator iterator = new IrisDataSetIterator(10, 150);
File minibatches = new File(testDir.toFile(),"mini-batch-dir"); File minibatches = new File(testDir.toFile(),"mini-batch-dir");
@ -62,7 +63,7 @@ public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMiniBatchBalanced(@TempDir Path testDir,Nd4jBackend backend) throws Exception { public void testMiniBatchBalanced(Nd4jBackend backend) throws Exception {
int miniBatchSize = 100; int miniBatchSize = 100;
DataSetIterator iterator = new IrisDataSetIterator(miniBatchSize, 150); DataSetIterator iterator = new IrisDataSetIterator(miniBatchSize, 150);

View File

@ -51,8 +51,10 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
@Slf4j @Slf4j
public class DataSetTest extends BaseNd4jTestWithBackends { public class DataSetTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @TempDir Path testDir;
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testViewIterator(Nd4jBackend backend) { public void testViewIterator(Nd4jBackend backend) {
DataSetIterator iter = new ViewIterator(new IrisDataSetIterator(150, 150).next(), 10); DataSetIterator iter = new ViewIterator(new IrisDataSetIterator(150, 150).next(), 10);
@ -106,9 +108,9 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSplitTestAndTrain (Nd4jBackend backend) { public void testSplitTestAndTrain(Nd4jBackend backend) {
INDArray labels = FeatureUtil.toOutcomeMatrix(new int[] {0, 0, 0, 0, 0, 0, 0, 0}, 1); INDArray labels = FeatureUtil.toOutcomeMatrix(new int[] {0, 0, 0, 0, 0, 0, 0, 0}, 1);
DataSet data = new DataSet(Nd4j.rand(8, 1), labels); DataSet data = new DataSet(Nd4j.rand(8, 1), labels);
@ -116,7 +118,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
assertEquals(train.getTrain().getLabels().length(), 6); assertEquals(train.getTrain().getLabels().length(), 6);
SplitTestAndTrain train2 = data.splitTestAndTrain(6, new Random(1)); SplitTestAndTrain train2 = data.splitTestAndTrain(6, new Random(1));
assertEquals(train.getTrain().getFeatures(), train2.getTrain().getFeatures(),getFailureMessage()); assertEquals(train.getTrain().getFeatures(), train2.getTrain().getFeatures(),getFailureMessage(backend));
DataSet x0 = new IrisDataSetIterator(150, 150).next(); DataSet x0 = new IrisDataSetIterator(150, 150).next();
SplitTestAndTrain testAndTrain = x0.splitTestAndTrain(10); SplitTestAndTrain testAndTrain = x0.splitTestAndTrain(10);
@ -144,7 +146,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
SplitTestAndTrain testAndTrainRng = x2.splitTestAndTrain(10, rngHere); SplitTestAndTrain testAndTrainRng = x2.splitTestAndTrain(10, rngHere);
assertArrayEquals(testAndTrainRng.getTrain().getFeatures().shape(), assertArrayEquals(testAndTrainRng.getTrain().getFeatures().shape(),
testAndTrain.getTrain().getFeatures().shape()); testAndTrain.getTrain().getFeatures().shape());
assertEquals(testAndTrainRng.getTrain().getFeatures(), testAndTrain.getTrain().getFeatures()); assertEquals(testAndTrainRng.getTrain().getFeatures(), testAndTrain.getTrain().getFeatures());
assertEquals(testAndTrainRng.getTrain().getLabels(), testAndTrain.getTrain().getLabels()); assertEquals(testAndTrainRng.getTrain().getLabels(), testAndTrain.getTrain().getLabels());
@ -154,13 +156,13 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLabelCounts(Nd4jBackend backend) { public void testLabelCounts(Nd4jBackend backend) {
DataSet x0 = new IrisDataSetIterator(150, 150).next(); DataSet x0 = new IrisDataSetIterator(150, 150).next();
assertEquals(0, x0.get(0).outcome(),getFailureMessage()); assertEquals(0, x0.get(0).outcome(),getFailureMessage(backend));
assertEquals( 0, x0.get(1).outcome(),getFailureMessage()); assertEquals( 0, x0.get(1).outcome(),getFailureMessage(backend));
assertEquals(2, x0.get(149).outcome(),getFailureMessage()); assertEquals(2, x0.get(149).outcome(),getFailureMessage(backend));
Map<Integer, Double> counts = x0.labelCounts(); Map<Integer, Double> counts = x0.labelCounts();
assertEquals(50, counts.get(0), 1e-1,getFailureMessage()); assertEquals(50, counts.get(0), 1e-1,getFailureMessage(backend));
assertEquals(50, counts.get(1), 1e-1,getFailureMessage()); assertEquals(50, counts.get(1), 1e-1,getFailureMessage(backend));
assertEquals(50, counts.get(2), 1e-1,getFailureMessage()); assertEquals(50, counts.get(2), 1e-1,getFailureMessage(backend));
} }
@ -694,14 +696,14 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
INDArray expLabels3d = Nd4j.create(3, 3, 4); INDArray expLabels3d = Nd4j.create(3, 3, 4);
expLabels3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)}, expLabels3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)},
l3d1); l3d1);
expLabels3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(), expLabels3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
NDArrayIndex.interval(0, 3)}, l3d2); NDArrayIndex.interval(0, 3)}, l3d2);
INDArray expLM3d = Nd4j.create(3, 3, 4); INDArray expLM3d = Nd4j.create(3, 3, 4);
expLM3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)}, expLM3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)},
lm3d1); lm3d1);
expLM3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(), expLM3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
NDArrayIndex.interval(0, 3)}, lm3d2); NDArrayIndex.interval(0, 3)}, lm3d2);
DataSet merged3d = DataSet.merge(Arrays.asList(ds3d1, ds3d2)); DataSet merged3d = DataSet.merge(Arrays.asList(ds3d1, ds3d2));
@ -752,52 +754,52 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testShuffleNd(Nd4jBackend backend) { public void testShuffleNd(Nd4jBackend backend) {
int numDims = 7; int numDims = 7;
int nLabels = 3; int nLabels = 3;
Random r = new Random(); Random r = new Random();
int[] shape = new int[numDims]; int[] shape = new int[numDims];
int entries = 1; int entries = 1;
for (int i = 0; i < numDims; i++) { for (int i = 0; i < numDims; i++) {
//randomly generating shapes bigger than 1 //randomly generating shapes bigger than 1
shape[i] = r.nextInt(4) + 2; shape[i] = r.nextInt(4) + 2;
entries *= shape[i]; entries *= shape[i];
} }
int labels = shape[0] * nLabels; int labels = shape[0] * nLabels;
INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape); INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape);
INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels); INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels);
DataSet ds = new DataSet(ds_data, ds_labels); DataSet ds = new DataSet(ds_data, ds_labels);
ds.shuffle(); ds.shuffle();
//Checking Nd dataset which is the data //Checking Nd dataset which is the data
for (int dim = 1; dim < numDims; dim++) { for (int dim = 1; dim < numDims; dim++) {
//get tensor along dimension - the order in every dimension but zero should be preserved
for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) {
//the difference between consecutive elements should be equal to the stride
for (int i = 0, j = 1; j < shape[dim]; i++, j++) {
int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i);
int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j);
int f_element_diff = f_next_element - f_element;
assertEquals(f_element_diff, ds_data.stride(dim));
}
}
}
//Checking 2d, features
int dim = 1;
//get tensor along dimension - the order in every dimension but zero should be preserved //get tensor along dimension - the order in every dimension but zero should be preserved
for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) { for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) {
//the difference between consecutive elements should be equal to the stride //the difference between consecutive elements should be equal to the stride
for (int i = 0, j = 1; j < nLabels; i++, j++) { for (int i = 0, j = 1; j < shape[dim]; i++, j++) {
int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i); int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i);
int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j); int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j);
int l_element_diff = l_next_element - l_element; int f_element_diff = f_next_element - f_element;
assertEquals(l_element_diff, ds_labels.stride(dim)); assertEquals(f_element_diff, ds_data.stride(dim));
} }
} }
}
//Checking 2d, features
int dim = 1;
//get tensor along dimension - the order in every dimension but zero should be preserved
for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) {
//the difference between consecutive elements should be equal to the stride
for (int i = 0, j = 1; j < nLabels; i++, j++) {
int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i);
int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j);
int l_element_diff = l_next_element - l_element;
assertEquals(l_element_diff, ds_labels.stride(dim));
}
}
} }
@ParameterizedTest @ParameterizedTest
@ -936,9 +938,9 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
//Checking if the features and labels are equal //Checking if the features and labels are equal
assertEquals(iDataSet.getFeatures(), assertEquals(iDataSet.getFeatures(),
dsList.get(i).getFeatures().get(all(), all(), interval(0, minTSLength + i))); dsList.get(i).getFeatures().get(all(), all(), interval(0, minTSLength + i)));
assertEquals(iDataSet.getLabels(), assertEquals(iDataSet.getLabels(),
dsList.get(i).getLabels().get(all(), all(), interval(0, minTSLength + i))); dsList.get(i).getLabels().get(all(), all(), interval(0, minTSLength + i)));
} }
} }
@ -964,8 +966,8 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
for (boolean lMask : b) { for (boolean lMask : b) {
DataSet ds = new DataSet((features ? f : null), DataSet ds = new DataSet((features ? f : null),
(labels ? (labelsSameAsFeatures ? f : l) : null), (fMask ? fm : null), (labels ? (labelsSameAsFeatures ? f : l) : null), (fMask ? fm : null),
(lMask ? lm : null)); (lMask ? lm : null));
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos); DataOutputStream dos = new DataOutputStream(baos);
@ -1009,7 +1011,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
boolean lMask = true; boolean lMask = true;
DataSet ds = new DataSet((features ? f : null), (labels ? (labelsSameAsFeatures ? f : l) : null), DataSet ds = new DataSet((features ? f : null), (labels ? (labelsSameAsFeatures ? f : l) : null),
(fMask ? fm : null), (lMask ? lm : null)); (fMask ? fm : null), (lMask ? lm : null));
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos); DataOutputStream dos = new DataOutputStream(baos);
@ -1098,7 +1100,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend backend) throws IOException { public void testDataSetMetaDataSerialization(Nd4jBackend backend) throws IOException {
for(boolean withMeta : new boolean[]{false, true}) { for(boolean withMeta : new boolean[]{false, true}) {
// create simple data set with meta data object // create simple data set with meta data object
@ -1129,7 +1131,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMultiDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend nd4jBackend) throws IOException { public void testMultiDataSetMetaDataSerialization(Nd4jBackend nd4jBackend) throws IOException {
for(boolean withMeta : new boolean[]{false, true}) { for(boolean withMeta : new boolean[]{false, true}) {
// create simple data set with meta data object // create simple data set with meta data object

View File

@ -106,7 +106,8 @@ public class KFoldIteratorTest extends BaseNd4jTestWithBackends {
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void checkCornerCaseException(Nd4jBackend backend) { public void checkCornerCaseException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
DataSet allData = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1), DataSet allData = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1),

View File

@ -21,27 +21,25 @@
package org.nd4j.linalg.dataset; package org.nd4j.linalg.dataset;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.nio.file.Path; import java.nio.file.Path;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTestWithBackends { public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMiniBatches(@TempDir Path testDir) throws Exception { public void testMiniBatches(Nd4jBackend backend) throws Exception {
DataSet load = new IrisDataSetIterator(150, 150).next(); DataSet load = new IrisDataSetIterator(150, 150).next();
final MiniBatchFileDataSetIterator iter = new MiniBatchFileDataSetIterator(load, 10, false, testDir.toFile()); final MiniBatchFileDataSetIterator iter = new MiniBatchFileDataSetIterator(load, 10, false, testDir.toFile());
while (iter.hasNext()) while (iter.hasNext())

View File

@ -39,8 +39,7 @@ public class CompositeDataSetPreProcessorTest extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_preConditionsIsNull_expect_NullPointerException(Nd4jBackend backend) { public void when_preConditionsIsNull_expect_NullPointerException(Nd4jBackend backend) {
assertThrows(NullPointerException.class,() -> { assertThrows(NullPointerException.class,() -> {

View File

@ -41,8 +41,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
return 'c'; return 'c';
} }
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_originalHeightIsZero_expect_IllegalArgumentException(Nd4jBackend backend) { public void when_originalHeightIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
@ -51,8 +50,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
}); });
} }
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_originalWidthIsZero_expect_IllegalArgumentException(Nd4jBackend backend) { public void when_originalWidthIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
@ -61,8 +59,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
}); });
} }
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_yStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) { public void when_yStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
@ -71,8 +68,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
}); });
} }
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_xStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) { public void when_xStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
@ -81,8 +77,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
}); });
} }
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) { public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
@ -91,8 +86,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
}); });
} }
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) { public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
@ -101,8 +95,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
}); });
} }
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) { public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
@ -111,8 +104,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
}); });
} }
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) { public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
// Assemble // Assemble

View File

@ -39,7 +39,8 @@ public class PermuteDataSetPreProcessorTest extends BaseNd4jTestWithBackends {
return 'c'; return 'c';
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) { public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
assertThrows(NullPointerException.class,() -> { assertThrows(NullPointerException.class,() -> {
// Assemble // Assemble

View File

@ -20,7 +20,6 @@
package org.nd4j.linalg.dataset.api.preprocessor; package org.nd4j.linalg.dataset.api.preprocessor;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends; import org.nd4j.linalg.BaseNd4jTestWithBackends;
@ -39,7 +38,8 @@ public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTestWithBacke
return 'c'; return 'c';
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) { public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
assertThrows(NullPointerException.class,() -> { assertThrows(NullPointerException.class,() -> {
// Assemble // Assemble

View File

@ -139,7 +139,7 @@ public class Nd4jTest extends BaseNd4jTestWithBackends {
INDArray actualResult = data.mean(0); INDArray actualResult = data.mean(0);
INDArray expectedResult = Nd4j.create(new double[] {3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., INDArray expectedResult = Nd4j.create(new double[] {3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6.,
6., 3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6.}, new int[] {2, 4, 4}); 6., 3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6.}, new int[] {2, 4, 4});
assertEquals(expectedResult, actualResult,getFailureMessage()); assertEquals(expectedResult, actualResult,getFailureMessage(backend));
} }
@ -154,7 +154,7 @@ public class Nd4jTest extends BaseNd4jTestWithBackends {
INDArray actualResult = data.var(false, 0); INDArray actualResult = data.var(false, 0);
INDArray expectedResult = Nd4j.create(new double[] {1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., INDArray expectedResult = Nd4j.create(new double[] {1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4.,
4., 1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4.}, new long[] {2, 4, 4}); 4., 1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4.}, new long[] {2, 4, 4});
assertEquals(expectedResult, actualResult,getFailureMessage()); assertEquals(expectedResult, actualResult,getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest

View File

@ -83,8 +83,7 @@ public class CloseableTests extends BaseNd4jTestWithBackends {
} }
} }
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAccessException_1(Nd4jBackend backend) { public void testAccessException_1(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {
@ -96,8 +95,7 @@ public class CloseableTests extends BaseNd4jTestWithBackends {
} }
@Test() @ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAccessException_2(Nd4jBackend backend) { public void testAccessException_2(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> { assertThrows(IllegalStateException.class,() -> {

View File

@ -384,7 +384,9 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
assertEquals(exp, arrayZ); assertEquals(exp, arrayZ);
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTypesValidation_1(Nd4jBackend backend) { public void testTypesValidation_1(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> { assertThrows(IllegalArgumentException.class,() -> {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.LONG); val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.LONG);
@ -397,7 +399,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTypesValidation_2(Nd4jBackend backend) { public void testTypesValidation_2(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> { assertThrows(RuntimeException.class,() -> {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
@ -412,7 +415,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
} }
@Test() @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTypesValidation_3(Nd4jBackend backend) { public void testTypesValidation_3(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> { assertThrows(RuntimeException.class,() -> {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
@ -422,6 +426,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
} }
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTypesValidation_4(Nd4jBackend backend) { public void testTypesValidation_4(Nd4jBackend backend) {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.DOUBLE); val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.DOUBLE);
@ -485,7 +491,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBoolFloatCast2(){ public void testBoolFloatCast2(Nd4jBackend backend){
val first = Nd4j.zeros(DataType.FLOAT, 3, 5000); val first = Nd4j.zeros(DataType.FLOAT, 3, 5000);
INDArray asBool = first.castTo(DataType.BOOL); INDArray asBool = first.castTo(DataType.BOOL);
INDArray not = Transforms.not(asBool); // INDArray not = Transforms.not(asBool); //
@ -516,7 +522,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAssignScalarSimple(){ public void testAssignScalarSimple(Nd4jBackend backend){
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
INDArray arr = Nd4j.scalar(dt, 10.0); INDArray arr = Nd4j.scalar(dt, 10.0);
arr.assign(2.0); arr.assign(2.0);
@ -526,7 +532,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSimple(){ public void testSimple(Nd4jBackend backend){
Nd4j.create(1); Nd4j.create(1);
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, DataType.LONG}) { for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, DataType.LONG}) {
// System.out.println("----- " + dt + " -----"); // System.out.println("----- " + dt + " -----");
@ -551,7 +557,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testWorkspaceBool(){ public void testWorkspaceBool(Nd4jBackend backend){
val conf = WorkspaceConfiguration.builder().minSize(10 * 1024 * 1024) val conf = WorkspaceConfiguration.builder().minSize(10 * 1024 * 1024)
.overallocationLimit(1.0).policyAllocation(AllocationPolicy.OVERALLOCATE) .overallocationLimit(1.0).policyAllocation(AllocationPolicy.OVERALLOCATE)
.policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL) .policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL)
@ -559,7 +565,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(conf, "WS"); val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(conf, "WS");
for( int i=0; i<10; i++ ) { for( int i = 0; i < 10; i++ ) {
try (val workspace = (Nd4jWorkspace)ws.notifyScopeEntered() ) { try (val workspace = (Nd4jWorkspace)ws.notifyScopeEntered() ) {
val bool = Nd4j.create(DataType.BOOL, 1, 10); val bool = Nd4j.create(DataType.BOOL, 1, 10);
val dbl = Nd4j.create(DataType.DOUBLE, 1, 10); val dbl = Nd4j.create(DataType.DOUBLE, 1, 10);
@ -574,8 +580,9 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
} }
} }
@Test @ParameterizedTest
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testArrayCreationFromPointer(Nd4jBackend backend) { public void testArrayCreationFromPointer(Nd4jBackend backend) {
val source = Nd4j.create(new double[]{1, 2, 3, 4, 5}); val source = Nd4j.create(new double[]{1, 2, 3, 4, 5});

View File

@ -40,13 +40,13 @@ public class NativeBlasTests extends BaseNd4jTestWithBackends {
@BeforeEach @BeforeEach
public void setUp(Nd4jBackend backend) { public void setUp() {
Nd4j.getExecutioner().enableDebugMode(true); Nd4j.getExecutioner().enableDebugMode(true);
Nd4j.getExecutioner().enableVerboseMode(true); Nd4j.getExecutioner().enableVerboseMode(true);
} }
@AfterEach @AfterEach
public void setDown(Nd4jBackend backend) { public void setDown() {
Nd4j.getExecutioner().enableDebugMode(false); Nd4j.getExecutioner().enableDebugMode(false);
Nd4j.getExecutioner().enableVerboseMode(false); Nd4j.getExecutioner().enableVerboseMode(false);
} }

View File

@ -77,18 +77,18 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5}); INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5}); INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
double sim = Transforms.cosineSim(vec1, vec2); double sim = Transforms.cosineSim(vec1, vec2);
assertEquals( 1, sim, 1e-1,getFailureMessage()); assertEquals( 1, sim, 1e-1,getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCosineDistance(){ public void testCosineDistance(Nd4jBackend backend){
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3}); INDArray vec1 = Nd4j.create(new float[] {1, 2, 3});
INDArray vec2 = Nd4j.create(new float[] {3, 5, 7}); INDArray vec2 = Nd4j.create(new float[] {3, 5, 7});
// 1-17*sqrt(2/581) // 1-17*sqrt(2/581)
double distance = Transforms.cosineDistance(vec1, vec2); double distance = Transforms.cosineDistance(vec1, vec2);
assertEquals(0.0025851, distance, 1e-7,getFailureMessage()); assertEquals(0.0025851, distance, 1e-7,getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest
@ -97,7 +97,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray arr = Nd4j.create(new double[] {55, 55}); INDArray arr = Nd4j.create(new double[] {55, 55});
INDArray arr2 = Nd4j.create(new double[] {60, 60}); INDArray arr2 = Nd4j.create(new double[] {60, 60});
double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).z().getDouble(0); double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).z().getDouble(0);
assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage()); assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest
@ -137,7 +137,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi(); INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi();
INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6); INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6);
Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1)); Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1));
assertEquals(scalarMax, postMax,getFailureMessage()); assertEquals(scalarMax, postMax,getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest
@ -147,14 +147,14 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1)); Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1));
for (int i = 0; i < linspace.length(); i++) { for (int i = 0; i < linspace.length(); i++) {
double val = linspace.getDouble(i); double val = linspace.getDouble(i);
assertTrue( val >= 0 && val <= 1,getFailureMessage()); assertTrue( val >= 0 && val <= 1,getFailureMessage(backend));
} }
INDArray linspace2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray linspace2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
Nd4j.getExecutioner().exec(new SetRange(linspace2, 2, 4)); Nd4j.getExecutioner().exec(new SetRange(linspace2, 2, 4));
for (int i = 0; i < linspace2.length(); i++) { for (int i = 0; i < linspace2.length(); i++) {
double val = linspace2.getDouble(i); double val = linspace2.getDouble(i);
assertTrue( val >= 2 && val <= 4,getFailureMessage()); assertTrue( val >= 2 && val <= 4,getFailureMessage(backend));
} }
} }
@ -163,7 +163,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
public void testNormMax(Nd4jBackend backend) { public void testNormMax(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).z().getDouble(0); double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).z().getDouble(0);
assertEquals(4, normMax, 1e-1,getFailureMessage()); assertEquals(4, normMax, 1e-1,getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest
@ -187,7 +187,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
public void testNorm2(Nd4jBackend backend) { public void testNorm2(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).z().getDouble(0); double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).z().getDouble(0);
assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage()); assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest
@ -198,7 +198,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray xDup = x.dup(); INDArray xDup = x.dup();
INDArray solution = Nd4j.valueArrayOf(5, 2.0); INDArray solution = Nd4j.valueArrayOf(5, 2.0);
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x})); opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
assertEquals(solution, x,getFailureMessage()); assertEquals(solution, x,getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest
@ -221,13 +221,13 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray xDup = x.dup(); INDArray xDup = x.dup();
INDArray solution = Nd4j.valueArrayOf(5, 2.0); INDArray solution = Nd4j.valueArrayOf(5, 2.0);
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x})); opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
assertEquals(solution, x,getFailureMessage()); assertEquals(solution, x,getFailureMessage(backend));
Sum acc = new Sum(x.dup()); Sum acc = new Sum(x.dup());
opExecutioner.exec(acc); opExecutioner.exec(acc);
assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
Prod prod = new Prod(x.dup()); Prod prod = new Prod(x.dup());
opExecutioner.exec(prod); opExecutioner.exec(prod);
assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
} }
@ -275,7 +275,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Variance variance = new Variance(x.dup(), true); Variance variance = new Variance(x.dup(), true);
opExecutioner.exec(variance); opExecutioner.exec(variance);
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
} }
@ -284,14 +284,14 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIamax(Nd4jBackend backend) { public void testIamax(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage()); assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIamax2(Nd4jBackend backend) { public void testIamax2(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage()); assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage(backend));
val op = new ArgAmax(linspace); val op = new ArgAmax(linspace);
int iamax = Nd4j.getExecutioner().exec(op)[0].getInt(0); int iamax = Nd4j.getExecutioner().exec(op)[0].getInt(0);
@ -307,11 +307,11 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Mean mean = new Mean(x); Mean mean = new Mean(x);
opExecutioner.exec(mean); opExecutioner.exec(mean);
assertEquals( 3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); assertEquals( 3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
Variance variance = new Variance(x.dup(), true); Variance variance = new Variance(x.dup(), true);
opExecutioner.exec(variance); opExecutioner.exec(variance);
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage()); assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest
@ -321,7 +321,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
val softMax = new SoftMax(arr); val softMax = new SoftMax(arr);
opExecutioner.exec((CustomOp) softMax); opExecutioner.exec((CustomOp) softMax);
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage()); assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
} }
@ -332,7 +332,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Pow pow = new Pow(oneThroughSix, 2); Pow pow = new Pow(oneThroughSix, 2);
Nd4j.getExecutioner().exec(pow); Nd4j.getExecutioner().exec(pow);
INDArray answer = Nd4j.create(new double[] {1, 4, 9, 16, 25, 36}); INDArray answer = Nd4j.create(new double[] {1, 4, 9, 16, 25, 36});
assertEquals(answer, pow.z(),getFailureMessage()); assertEquals(answer, pow.z(),getFailureMessage(backend));
} }
@ -384,7 +384,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Log log = new Log(slice); Log log = new Log(slice);
opExecutioner.exec(log); opExecutioner.exec(log);
INDArray assertion = Nd4j.create(new double[] {0., 1.09861229, 1.60943791}); INDArray assertion = Nd4j.create(new double[] {0., 1.09861229, 1.60943791});
assertEquals(assertion, slice,getFailureMessage()); assertEquals(assertion, slice,getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest
@ -572,7 +572,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
expected[i] = (float) Math.exp(slice.getDouble(i)); expected[i] = (float) Math.exp(slice.getDouble(i));
Exp exp = new Exp(slice); Exp exp = new Exp(slice);
opExecutioner.exec(exp); opExecutioner.exec(exp);
assertEquals( Nd4j.create(expected), slice,getFailureMessage()); assertEquals( Nd4j.create(expected), slice,getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest
@ -582,7 +582,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
val softMax = new SoftMax(arr); val softMax = new SoftMax(arr);
opExecutioner.exec((CustomOp) softMax); opExecutioner.exec((CustomOp) softMax);
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage()); assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
} }
@ParameterizedTest @ParameterizedTest

Some files were not shown because too many files have changed in this diff Show More