Unify nd4j test profiles, get rid of old modules, fix more parameter issues with junit 5 tests

master
agibsonccc 2021-03-18 10:58:50 +09:00
parent e0077c38a9
commit ad4f47096c
135 changed files with 789 additions and 1911 deletions

View File

@ -31,7 +31,7 @@ jobs:
protoc --version
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
export OMP_NUM_THREADS=1
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
windows-x86_64:
runs-on: windows-2019
@ -44,7 +44,7 @@ jobs:
run: |
set "PATH=C:\msys64\usr\bin;%PATH%"
export OMP_NUM_THREADS=1
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
@ -60,5 +60,5 @@ jobs:
run: |
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
export OMP_NUM_THREADS=1
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test

View File

@ -31,7 +31,7 @@ jobs:
protoc --version
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
export OMP_NUM_THREADS=1
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
windows-x86_64:
runs-on: windows-2019
@ -44,7 +44,7 @@ jobs:
run: |
set "PATH=C:\msys64\usr\bin;%PATH%"
export OMP_NUM_THREADS=1
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
@ -60,5 +60,5 @@ jobs:
run: |
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
export OMP_NUM_THREADS=1
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test

View File

@ -34,5 +34,5 @@ jobs:
cmake --version
protoc --version
export OMP_NUM_THREADS=1
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Ptest-nd4j-native --also-make clean test
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Pnd4j-tests-cpu --also-make clean test

View File

@ -109,10 +109,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -30,6 +30,8 @@ import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.loader.FileBatch;
import java.io.File;
@ -40,13 +42,16 @@ import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.linalg.factory.Nd4jBackend;
@DisplayName("File Batch Record Reader Test")
class FileBatchRecordReaderTest extends BaseND4JTest {
public class FileBatchRecordReaderTest extends BaseND4JTest {
@TempDir Path testDir;
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Csv")
void testCsv(@TempDir Path testDir) throws Exception {
void testCsv(Nd4jBackend backend) throws Exception {
// This is an unrealistic use case - one line/record per CSV
File baseDir = testDir.toFile();
List<File> fileList = new ArrayList<>();
@ -75,9 +80,10 @@ class FileBatchRecordReaderTest extends BaseND4JTest {
}
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Csv Sequence")
void testCsvSequence(@TempDir Path testDir) throws Exception {
void testCsvSequence(Nd4jBackend backend) throws Exception {
// CSV sequence - 3 lines per file, 10 files
File baseDir = testDir.toFile();
List<File> fileList = new ArrayList<>();

View File

@ -21,7 +21,6 @@ package org.datavec.api.transform.ops;
import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;

View File

@ -60,10 +60,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -119,10 +119,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -59,10 +59,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -57,10 +57,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -65,10 +65,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -61,25 +61,18 @@
<artifactId>nd4j-common</artifactId>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-geo</artifactId>
<groupId>org.nd4j</groupId>
<artifactId>python4j-numpy</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-python</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -29,7 +29,6 @@ import org.datavec.api.transform.reduce.Reducer;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.*;
import org.datavec.python.PythonTransform;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
@ -39,7 +38,6 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertTimeout;
@ -166,37 +164,8 @@ class ExecutionTest {
List<List<Writable>> out = outRdd;
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
out = new ArrayList<>(out);
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
assertEquals(expOut, out);
}
@Test
@Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
@DisplayName("Test Python Execution Ndarray")
void testPythonExecutionNdarray() {
assertTimeout(ofMillis(60000), () -> {
Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build();
TransformProcess transformProcess = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build();
List<List<Writable>> functions = new ArrayList<>();
List<Writable> firstRow = new ArrayList<>();
INDArray firstArr = Nd4j.linspace(1, 4, 4);
INDArray secondArr = Nd4j.linspace(1, 4, 4);
firstRow.add(new NDArrayWritable(firstArr));
firstRow.add(new NDArrayWritable(secondArr));
functions.add(firstRow);
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
INDArray expected = Transforms.sin(firstArr);
INDArray secondExpected = Transforms.cos(secondArr);
assertEquals(expected, firstResult);
assertEquals(secondExpected, secondResult);
});
}
}

View File

@ -1,155 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.local.transforms.transform;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.geo.LocationType;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform;
import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform;
import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import java.io.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* @author saudet
*/
public class TestGeoTransforms {
@BeforeAll
public static void beforeClass() throws Exception {
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath());
}
@AfterClass
public static void afterClass(){
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, "");
}
@Test
public void testCoordinatesDistanceTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
.build();
Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(4, out.numColumns());
assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
out.getColumnTypes());
assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
new DoubleWritable(Math.sqrt(160))),
transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
new Text("10|5"))));
}
@Test
public void testIPAddressToCoordinatesTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("column").build();
Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER");
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(1, out.getColumnMetaData().size());
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
String in = "81.2.69.160";
double latitude = 51.5142;
double longitude = -0.0931;
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
assertEquals(2, coordinates.length);
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
//Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(transform);
byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
ObjectInputStream ois = new ObjectInputStream(bais);
Transform deserialized = (Transform) ois.readObject();
writables = deserialized.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
//System.out.println(Arrays.toString(coordinates));
assertEquals(2, coordinates.length);
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
}
@Test
public void testIPAddressToLocationTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("column").build();
LocationType[] locationTypes = LocationType.values();
String in = "81.2.69.160";
String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167",
"51.5142:-0.0931", "", "England", "6269131"}; //Note: no postcode in this test DB for this record
for (int i = 0; i < locationTypes.length; i++) {
LocationType locationType = locationTypes[i];
String location = locations[i];
Transform transform = new IPAddressToLocationTransform("column", locationType);
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(1, out.getColumnMetaData().size());
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
assertEquals(location, writables.get(0).toString());
//System.out.println(location);
}
}
}

View File

@ -1,386 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.local.transforms.transform;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.condition.Condition;
import org.datavec.api.transform.filter.ConditionFilter;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.schema.Schema;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.api.writable.*;
import org.datavec.python.PythonCondition;
import org.datavec.python.PythonTransform;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.datavec.api.transform.schema.Schema.Builder;
import static org.junit.jupiter.api.Assertions.*;
@NotThreadSafe
public class TestPythonTransformProcess {
@Test()
public void testStringConcat() throws Exception{
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnString("col1")
.addColumnString("col2");
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnString("col3");
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).build();
List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
List<Writable> outputs = tp.execute(inputs);
assertEquals((outputs.get(0)).toString(), "Hello ");
assertEquals((outputs.get(1)).toString(), "World!");
assertEquals((outputs.get(2)).toString(), "Hello World!");
}
@Test()
@Timeout(60000L)
public void testMixedTypes() throws Exception {
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")
.addColumnFloat("col2")
.addColumnString("col3")
.addColumnDouble("col4");
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnInteger("col5");
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.inputSchema(initialSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList(new IntWritable(10),
new FloatWritable(3.5f),
new Text("5"),
new DoubleWritable(2.0)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(((LongWritable)outputs.get(4)).get(), 36);
}
@Test()
@Timeout(60000L)
public void testNDArray() throws Exception {
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(shape);
INDArray arr2 = Nd4j.rand(shape);
INDArray expectedOutput = arr1.add(arr2);
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnNDArray("col3", shape);
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList(
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
}
@Test()
@Timeout(60000L)
public void testNDArray2() throws Exception {
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(shape);
INDArray arr2 = Nd4j.rand(shape);
INDArray expectedOutput = arr1.add(arr2);
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnNDArray("col3", shape);
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList(
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
}
@Test()
@Timeout(60000L)
public void testNDArrayMixed() throws Exception{
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnNDArray("col3", shape);
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).build();
List<Writable> inputs = Arrays.asList(
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
}
@Test()
@Timeout(60000L)
public void testPythonFilter() {
Schema schema = new Builder().addColumnInteger("column").build();
Condition condition = new PythonCondition(
"f = lambda: column < 0"
);
condition.setInputSchema(schema);
Filter filter = new ConditionFilter(condition);
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
}
@Test()
@Timeout(60000L)
public void testPythonFilterAndTransform() throws Exception {
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")
.addColumnFloat("col2")
.addColumnString("col3")
.addColumnDouble("col4");
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnString("col6");
Schema finalSchema = schemaBuilder.build();
Condition condition = new PythonCondition(
"f = lambda: col1 < 0 and col2 > 10.0"
);
condition.setInputSchema(initialSchema);
Filter filter = new ConditionFilter(condition);
String pythonCode = "col6 = str(col1 + col2)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).filter(
filter
).build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(
Arrays.asList(
(Writable)
new IntWritable(5),
new FloatWritable(3.0f),
new Text("abcd"),
new DoubleWritable(2.1))
);
inputs.add(
Arrays.asList(
(Writable)
new IntWritable(-3),
new FloatWritable(3.0f),
new Text("abcd"),
new DoubleWritable(2.1))
);
inputs.add(
Arrays.asList(
(Writable)
new IntWritable(5),
new FloatWritable(11.2f),
new Text("abcd"),
new DoubleWritable(2.1))
);
LocalTransformExecutor.execute(inputs,tp);
}
@Test
public void testPythonTransformNoOutputSpecified() throws Exception {
PythonTransform pythonTransform = PythonTransform.builder()
.code("a += 2; b = 'hello world'")
.returnAllInputs(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable)new IntWritable(1)));
Schema inputSchema = new Builder()
.addColumnInteger("a")
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertEquals(3,execute.get(0).get(0).toInt());
assertEquals("hello world",execute.get(0).get(1).toString());
}
@Test
public void testNumpyTransform() {
PythonTransform pythonTransform = PythonTransform.builder()
.code("a += 2; b = 'hello world'")
.returnAllInputs(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
Schema inputSchema = new Builder()
.addColumnNDArray("a",new long[]{1,1})
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertFalse(execute.isEmpty());
assertNotNull(execute.get(0));
assertNotNull(execute.get(0).get(0));
assertNotNull(execute.get(0).get(1));
assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get());
assertEquals("hello world",execute.get(0).get(1).toString());
}
@Test
public void testWithSetupRun() throws Exception {
PythonTransform pythonTransform = PythonTransform.builder()
.code("five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b + five\n"+
" return {'c':c}\n\n")
.returnAllInputs(true)
.setupAndRun(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)),
new NDArrayWritable(Nd4j.scalar(2).reshape(1,1))));
Schema inputSchema = new Builder()
.addColumnNDArray("a",new long[]{1,1})
.addColumnNDArray("b", new long[]{1, 1})
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertFalse(execute.isEmpty());
assertNotNull(execute.get(0));
assertNotNull(execute.get(0).get(0));
assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get());
}
}

View File

@ -128,10 +128,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -92,6 +92,10 @@
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
</dependency>
<dependency>
<groupId>org.junit.vintage</groupId>
<artifactId>junit-vintage-engine</artifactId>
@ -154,7 +158,7 @@
<skip>${skipTestResourceEnforcement}</skip>
<rules>
<requireActiveProfile>
<profiles>test-nd4j-native,test-nd4j-cuda-11.0</profiles>
<profiles>nd4j-tests-cpu,nd4j-tests-cuda</profiles>
<all>false</all>
</requireActiveProfile>
</rules>
@ -163,23 +167,6 @@
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<argLine></argLine>
<!--
By default: Surefire will set the classpath based on the manifest. Because tests are not included
in the JAR, any tests that rely on class path scanning for resources in the tests directory will not
function correctly without this configuration.
For example, tests for custom transforms (where the custom transform is defined in the test directory)
will fail due to the custom transform not being found on the classpath.
http://maven.apache.org/surefire/maven-surefire-plugin/examples/class-loading.html
-->
<useSystemClassLoader>true</useSystemClassLoader>
<useManifestOnlyJar>false</useManifestOnlyJar>
</configuration>
</plugin>
<plugin>
<groupId>org.eclipse.m2e</groupId>
<artifactId>lifecycle-mapping</artifactId>
@ -249,7 +236,7 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
@ -266,7 +253,7 @@
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
@ -286,9 +273,6 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
</configuration>
</plugin>
</plugins>
</build>

View File

@ -64,7 +64,7 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
@ -75,7 +75,7 @@
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>

View File

@ -56,10 +56,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -166,7 +166,7 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
@ -177,7 +177,7 @@
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>

View File

@ -23,7 +23,6 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -34,7 +33,6 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Early Termination Data Set Iterator Test")
class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {

View File

@ -21,19 +21,16 @@ package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Early Termination Multi Data Set Iterator Test")

View File

@ -34,7 +34,6 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -46,7 +45,6 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@Disabled
@DisplayName("Attention Layer Test")

View File

@ -105,11 +105,12 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
<build>
<plugins>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<inherited>true</inherited>
<configuration>
<skip>true</skip>
</configuration>
@ -118,7 +119,7 @@
</build>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -56,10 +56,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -50,10 +50,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -45,10 +45,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -54,10 +54,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -112,7 +112,7 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
@ -123,7 +123,7 @@
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>

View File

@ -72,10 +72,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -306,7 +306,7 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
@ -317,7 +317,7 @@
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>

View File

@ -101,10 +101,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -49,10 +49,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -127,10 +127,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -102,7 +102,7 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
@ -113,7 +113,7 @@
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>

View File

@ -99,10 +99,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -44,10 +44,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -89,10 +89,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -88,10 +88,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -90,10 +90,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -105,10 +105,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -182,10 +182,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -77,10 +77,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -104,10 +104,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -141,10 +141,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -79,10 +79,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -426,10 +426,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -44,10 +44,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -87,10 +87,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -117,10 +117,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -143,6 +143,10 @@
</extensions>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-enforcer-plugin</artifactId>
@ -158,7 +162,7 @@
<skip>${skipBackendChoice}</skip>
<rules>
<requireActiveProfile>
<profiles>test-nd4j-native,test-nd4j-cuda-11.0</profiles>
<profiles>nd4j-tests-cpu,nd4j-tests-cuda</profiles>
<all>false</all>
</requireActiveProfile>
</rules>
@ -227,43 +231,6 @@
</plugin>
</plugins>
<pluginManagement>
<plugins>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<inherited>true</inherited>
<configuration>
<!--
By default: Surefire will set the classpath based on the manifest. Because tests are not included
in the JAR, any tests that rely on class path scanning for resources in the tests directory will not
function correctly without this configuration.
For example, tests for custom layers (where the custom layer is defined in the test directory)
will fail due to the custom layer not being found on the classpath.
http://maven.apache.org/surefire/maven-surefire-plugin/examples/class-loading.html
-->
<useSystemClassLoader>true</useSystemClassLoader>
<useManifestOnlyJar>false</useManifestOnlyJar>
<argLine> -Dfile.encoding=UTF-8 -Xmx8g "</argLine>
<includes>
<!-- Default setting only runs tests that start/end with "Test" -->
<include>*.java</include>
<include>**/*.java</include>
</includes>
</configuration>
<dependencies>
<dependency>
<groupId>org.apache.maven.surefire</groupId>
<artifactId>surefire-junit-platform</artifactId>
<version>${maven-surefire-plugin.version}</version>
</dependency>
</dependencies>
</plugin>
<plugin>
<groupId>org.eclipse.m2e</groupId>
<artifactId>lifecycle-mapping</artifactId>
</plugin>
</plugins>
</pluginManagement>
</build>
<profiles>
@ -290,10 +257,10 @@
<module>deeplearning4j-cuda</module>
</modules>
</profile>
<!-- For running unit tests with nd4j-native: "mvn clean test -P test-nd4j-native"
<!-- For running unit tests with nd4j-native: "mvn clean test -P nd4j-tests-cpu"
Note that this excludes DL4J-cuda -->
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
@ -311,70 +278,10 @@
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<inherited>true</inherited>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<version>${junit.version}</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<version>${junit.version}</version>
</dependency>
<dependency>
<groupId>org.apache.maven.surefire</groupId>
<artifactId>surefire-junit-platform</artifactId>
<version>${maven-surefire-plugin.version}</version>
</dependency>
</dependencies>
<configuration>
<environmentVariables>
</environmentVariables>
<testSourceDirectory>src/test/java</testSourceDirectory>
<includes>
<include>*.java</include>
<include>**/*.java</include>
<include>**/Test*.java</include>
<include>**/*Test.java</include>
<include>**/*TestCase.java</include>
</includes>
<junitArtifactName>org.junit.jupiter:junit-jupiter-engine</junitArtifactName>
<systemPropertyVariables>
<org.nd4j.linalg.defaultbackend>
org.nd4j.linalg.cpu.nativecpu.CpuBackend
</org.nd4j.linalg.defaultbackend>
<org.nd4j.linalg.tests.backendstorun>
org.nd4j.linalg.cpu.nativecpu.CpuBackend
</org.nd4j.linalg.tests.backendstorun>
</systemPropertyVariables>
<!--
Maximum heap size was set to 8g, as a minimum required value for tests run.
Depending on a build machine, default value is not always enough.
For testing large zoo models, this may not be enough (so comment it out).
-->
<argLine></argLine>
</configuration>
</plugin>
</plugins>
</build>
</profile>
<!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" -->
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
@ -392,43 +299,6 @@
<scope>test</scope>
</dependency>
</dependencies>
<!-- Default to ALL modules here, unlike nd4j-native -->
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>${maven-surefire-plugin.version}</version>
<configuration>
<environmentVariables>
</environmentVariables>
<testSourceDirectory>src/test/java</testSourceDirectory>
<includes>
<include>*.java</include>
<include>**/*.java</include>
<include>**/Test*.java</include>
<include>**/*Test.java</include>
<include>**/*TestCase.java</include>
</includes>
<junitArtifactName>org.junit.jupiter:junit-jupiter</junitArtifactName>
<systemPropertyVariables>
<org.nd4j.linalg.defaultbackend>
org.nd4j.linalg.jcublas.JCublasBackend
</org.nd4j.linalg.defaultbackend>
<org.nd4j.linalg.tests.backendstorun>
org.nd4j.linalg.jcublas.JCublasBackend
</org.nd4j.linalg.tests.backendstorun>
</systemPropertyVariables>
<!--
Maximum heap size was set to 6g, as a minimum required value for tests run.
Depending on a build machine, default value is not always enough.
-->
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
</project>

View File

@ -5,7 +5,7 @@ Linux
[INFO] Total time: 14.610 s
[INFO] Finished at: 2021-03-06T15:35:28+09:00
[INFO] ------------------------------------------------------------------------
[WARNING] The requested profile "test-nd4j-native" could not be activated because it does not exist.
[WARNING] The requested profile "nd4j-tests-cpu" could not be activated because it does not exist.
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
[ERROR]
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
@ -749,7 +749,7 @@ make[1]: Leaving directory '/c/Users/agibs/Documents/GitHub/eclipse-deeplearning
[INFO] Total time: 15.482 s
[INFO] Finished at: 2021-03-06T15:27:35+09:00
[INFO] ------------------------------------------------------------------------
[WARNING] The requested profile "test-nd4j-native" could not be activated because it does not exist.
[WARNING] The requested profile "nd4j-tests-cpu" could not be activated because it does not exist.
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
[ERROR]
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.

View File

@ -0,0 +1,19 @@
Const,in_0
Const,while/Const
Const,while/add/y
Identity,in_0/read
Enter,while/Enter
Enter,while/Enter_1
Merge,while/Merge
Merge,while/Merge_1
Less,while/Less
LoopCond,while/LoopCond
Switch,while/Switch
Switch,while/Switch_1
Identity,while/Identity
Exit,while/Exit
Identity,while/Identity_1
Exit,while/Exit_1
Add,while/add
NextIteration,while/NextIteration_1
NextIteration,while/NextIteration

View File

@ -0,0 +1,16 @@
Identity,in_0/read
Enter,while/Enter
Enter,while/Enter_1
Merge,while/Merge
Merge,while/Merge_1
Less,while/Less
LoopCond,while/LoopCond
Switch,while/Switch
Switch,while/Switch_1
Identity,while/Identity
Exit,while/Exit
Identity,while/Identity_1
Exit,while/Exit_1
Add,while/add
NextIteration,while/NextIteration_1
NextIteration,while/NextIteration

View File

@ -0,0 +1,19 @@
in_0
while/Const
while/add/y
in_0/read
while/Enter
while/Enter_1
while/Merge
while/Merge_1
while/Less
while/LoopCond
while/Switch
while/Switch_1
while/Identity
while/Exit
while/Identity_1
while/Exit_1
while/add
while/NextIteration_1
while/NextIteration

View File

@ -303,7 +303,7 @@
For testing large zoo models, this may not be enough (so comment it out).
-->
<argLine>-Dfile.encoding=UTF-8 "</argLine>
<argLine>-Dfile.encoding=UTF-8 </argLine>
</configuration>
</plugin>
</plugins>

View File

@ -27,6 +27,7 @@ import java.util.List;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.autodiff.samediff.SDVariable;
@ -482,7 +483,7 @@ public class LayerOpValidation extends BaseOpValidation {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConv3d(Nd4jBackend backend) {
public void testConv3d(Nd4jBackend backend, TestInfo testInfo) {
//Pooling3d, Conv3D, batch norm
Nd4j.getRandom().setSeed(12345);
@ -573,7 +574,7 @@ public class LayerOpValidation extends BaseOpValidation {
tc.testName(msg);
String error = OpValidation.validate(tc);
if (error != null) {
failed.add(name);
failed.add(testInfo.getTestMethod().get().getName());
}
}
}
@ -1353,7 +1354,8 @@ public class LayerOpValidation extends BaseOpValidation {
assertNull(err, err);
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
int nIn = 3;
@ -1382,7 +1384,8 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345);
@ -1405,7 +1408,8 @@ public class LayerOpValidation extends BaseOpValidation {
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
Nd4j.getRandom().setSeed(12345);

View File

@ -22,6 +22,7 @@ package org.nd4j.autodiff.opvalidation;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
@ -664,6 +665,7 @@ public class MiscOpValidation extends BaseOpValidation {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testMmulGradientManual(Nd4jBackend backend) {
SameDiff sameDiff = SameDiff.create();
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);

View File

@ -69,7 +69,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
@AfterEach
public void tearDown(Nd4jBackend backend) {
public void tearDown() {
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
}

View File

@ -28,6 +28,7 @@ import lombok.val;
import org.apache.commons.math3.linear.LUDecomposition;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.OpValidationSuite;
@ -83,7 +84,7 @@ public class ShapeOpValidation extends BaseOpValidation {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConcat(Nd4jBackend backend) {
public void testConcat(Nd4jBackend backend, TestInfo testInfo) {
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
int[] concatDim = new int[]{0, 0, 0};
List<List<int[]>> origShapes = new ArrayList<>();
@ -115,7 +116,7 @@ public class ShapeOpValidation extends BaseOpValidation {
String error = OpValidation.validate(tc);
if(error != null){
failed.add(name);
failed.add(testInfo.getTestMethod().get().getName());
}
}
@ -285,7 +286,7 @@ public class ShapeOpValidation extends BaseOpValidation {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSqueezeGradient(Nd4jBackend backend) {
public void testSqueezeGradient(Nd4jBackend backend,TestInfo testInfo) {
val origShape = new long[]{3, 4, 5};
List<String> failed = new ArrayList<>();
@ -339,7 +340,7 @@ public class ShapeOpValidation extends BaseOpValidation {
String error = OpValidation.validate(tc, true);
if(error != null){
failed.add(name);
failed.add(testInfo.getTestMethod().get().getName());
}
}
}
@ -580,8 +581,9 @@ public class ShapeOpValidation extends BaseOpValidation {
return Long.MAX_VALUE;
}
@Test()
public void testStack(Nd4jBackend backend) {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testStack(Nd4jBackend backend,TestInfo testInfo) {
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
@ -661,7 +663,7 @@ public class ShapeOpValidation extends BaseOpValidation {
String error = OpValidation.validate(tc);
if(error != null){
failed.add(name);
failed.add(testInfo.getTestMethod().get().getName());
}
}
}

View File

@ -72,6 +72,8 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@Override
public char ordering(){
@ -82,7 +84,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void testBasic(Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create();
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() );
@ -121,7 +123,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
int numOutputs = fg.outputsLength();
List<IntPair> outputs = new ArrayList<>(numOutputs);
for( int i=0; i<numOutputs; i++ ){
for( int i = 0; i < numOutputs; i++) {
outputs.add(fg.outputs(i));
}
@ -138,7 +140,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void testSimple(Nd4jBackend backend) throws Exception {
for( int i = 0; i < 10; i++ ) {
for(boolean execFirst : new boolean[]{false, true}) {
log.info("Starting test: i={}, execFirst={}", i, execFirst);
@ -268,7 +270,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void testTrainingSerde(Nd4jBackend backend) throws Exception {
//Ensure 2 things:
//1. Training config is serialized/deserialized correctly

View File

@ -109,7 +109,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
}
@BeforeEach
public void before(Nd4jBackend backend) {
public void before() {
Nd4j.create(1);
initialType = Nd4j.dataType();
@ -118,7 +118,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
}
@AfterEach
public void after(Nd4jBackend backend) {
public void after() {
Nd4j.setDataType(initialType);
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);

View File

@ -21,9 +21,6 @@
package org.nd4j.autodiff.samediff.listeners;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
@ -47,11 +44,11 @@ import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@Override
@ -97,7 +94,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void testCheckpointEveryEpoch(Nd4jBackend backend) throws Exception {
File dir = testDir.toFile();
SameDiff sd = getModel();
@ -132,7 +129,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void testCheckpointEvery5Iter(Nd4jBackend backend) throws Exception {
File dir = testDir.toFile();
SameDiff sd = getModel();
@ -172,7 +169,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void testCheckpointListenerEveryTimeUnit(Nd4jBackend backend) throws Exception {
File dir = testDir.toFile();
SameDiff sd = getModel();
@ -217,7 +214,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void testCheckpointListenerKeepLast3AndEvery3(Nd4jBackend backend) throws Exception {
File dir = testDir.toFile();
SameDiff sd = getModel();

View File

@ -23,6 +23,7 @@ package org.nd4j.autodiff.samediff.listeners;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
@ -49,6 +50,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@Override
public char ordering() {
@ -59,7 +61,8 @@ public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
@Disabled
public void testProfilingListenerSimple(Nd4jBackend backend) throws Exception {
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);

View File

@ -64,6 +64,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j
public class FileReadWriteTests extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@Override
public char ordering(){
@ -81,7 +82,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
public void testSimple(Nd4jBackend backend) throws IOException {
SameDiff sd = SameDiff.create();
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
SDVariable sum = v.sum();
@ -163,7 +164,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
//Append a number of events
w.registerEventName("accuracy");
for( int iter=0; iter<3; iter++) {
for( int iter = 0; iter < 3; iter++) {
long t = System.currentTimeMillis();
w.writeScalarEvent("accuracy", LogFileWriter.EventSubtype.EVALUATION, t, iter, 0, 0.5 + 0.1 * iter);
}
@ -175,7 +176,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
UIAddName addName = (UIAddName) events.get(0).getRight();
assertEquals("accuracy", addName.name());
for( int i=1; i<4; i++ ){
for( int i = 1; i < 4; i++ ){
FlatArray fa = (FlatArray) events.get(i).getRight();
INDArray arr = Nd4j.createFromFlatArray(fa);
@ -186,7 +187,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{
public void testNullBinLabels(Nd4jBackend backend) throws Exception{
File dir = testDir.toFile();
File f = new File(dir, "temp.bin");
LogFileWriter w = new LogFileWriter(f);

View File

@ -56,6 +56,8 @@ import static org.junit.jupiter.api.Assertions.*;
public class UIListenerTest extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@Override
public char ordering() {
return 'c';
@ -65,7 +67,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void testUIListenerBasic(Nd4jBackend backend) throws Exception {
Nd4j.getRandom().setSeed(12345);
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
@ -102,7 +104,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void testUIListenerContinue(Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
SameDiff sd1 = getSimpleNet();
@ -194,7 +196,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void testUIListenerBadContinue(Nd4jBackend backend) throws Exception {
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
SameDiff sd1 = getSimpleNet();
@ -275,7 +277,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
}
private static SameDiff getSimpleNet(){
private static SameDiff getSimpleNet() {
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);

View File

@ -22,6 +22,7 @@ package org.nd4j.evaluation;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
@ -48,6 +49,7 @@ public class NewInstanceTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testNewInstances(Nd4jBackend backend) {
boolean print = true;
Nd4j.getRandom().setSeed(12345);

View File

@ -20,7 +20,7 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.evaluation.classification.ROC;
@ -42,14 +42,15 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class ROCBinaryTest extends BaseNd4jTestWithBackends {
@Override
public char ordering() {
return 'c';
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testROCBinary(Nd4jBackend backend) {
//Compare ROCBinary to ROC class
@ -144,7 +145,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocBinaryMerging(Nd4jBackend backend) {
for (int nSteps : new int[]{30, 0}) { //0 == exact
@ -175,7 +176,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinaryPerOutputMasking(Nd4jBackend backend) {
@ -216,7 +217,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinary3d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
@ -251,7 +252,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinary4d(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
@ -286,7 +287,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinary3dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
@ -348,7 +349,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCBinary4dMasking(Nd4jBackend backend) {
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

View File

@ -20,6 +20,7 @@
package org.nd4j.evaluation;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
@ -82,16 +83,16 @@ public class ROCTest extends BaseNd4jTestWithBackends {
expFPR.put(10 / 10.0, 0.0 / totalNegatives);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocBasic(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax)
INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
INDArray actual = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
{0, 1}, {0, 1}});
{0, 1}, {0, 1}});
ROC roc = new ROC(10);
roc.eval(actual, predictions);
@ -126,15 +127,15 @@ public class ROCTest extends BaseNd4jTestWithBackends {
assertEquals(1.0, auc, 1e-6);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocBasicSingleClass(Nd4jBackend backend) {
//1 output here - single probability value (sigmoid)
//add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
INDArray predictions =
Nd4j.create(new double[] {0.001, 0.101, 0.201, 0.301, 0.401, 0.501, 0.601, 0.701, 0.801, 0.901},
new int[] {10, 1});
Nd4j.create(new double[] {0.001, 0.101, 0.201, 0.301, 0.401, 0.501, 0.601, 0.701, 0.801, 0.901},
new int[] {10, 1});
INDArray actual = Nd4j.create(new double[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}, new int[] {10, 1});
@ -165,7 +166,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRoc(Nd4jBackend backend) {
//Previous tests allowed for a perfect classifier with right threshold...
@ -173,7 +174,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}});
INDArray prediction = Nd4j.create(new double[][] {{0.199, 0.801}, {0.499, 0.501}, {0.399, 0.601},
{0.799, 0.201}, {0.899, 0.101}});
{0.799, 0.201}, {0.899, 0.101}});
Map<Double, Double> expTPR = new HashMap<>();
double totalPositives = 2.0;
@ -251,27 +252,27 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
//Same as first test...
//2 outputs here - probability distribution over classes (softmax)
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
{0, 1}, {0, 1}});
{0, 1}, {0, 1}});
INDArray predictions3d = Nd4j.create(2, 2, 5);
INDArray firstTSp =
predictions3d.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
predictions3d.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
assertArrayEquals(new long[] {5, 2}, firstTSp.shape());
firstTSp.assign(predictions2d.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all()));
INDArray secondTSp =
predictions3d.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
predictions3d.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
assertArrayEquals(new long[] {5, 2}, secondTSp.shape());
secondTSp.assign(predictions2d.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all()));
@ -299,23 +300,23 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocTimeSeriesMasking(Nd4jBackend backend) {
//2 outputs here - probability distribution over classes (softmax)
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
{0, 1}, {0, 1}});
{0, 1}, {0, 1}});
//Create time series data... first time series: length 4. Second time series: length 6
INDArray predictions3d = Nd4j.create(2, 2, 6);
INDArray tad = predictions3d.tensorAlongDimension(0, 1, 2).transpose();
tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
.assign(predictions2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
.assign(predictions2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
tad = predictions3d.tensorAlongDimension(1, 1, 2).transpose();
tad.assign(predictions2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));
@ -324,7 +325,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
INDArray labels3d = Nd4j.create(2, 2, 6);
tad = labels3d.tensorAlongDimension(0, 1, 2).transpose();
tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
.assign(actual2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
.assign(actual2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
tad = labels3d.tensorAlongDimension(1, 1, 2).transpose();
tad.assign(actual2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));
@ -350,7 +351,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
@ -381,7 +382,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCompare2Vs3Classes(Nd4jBackend backend) {
@ -431,7 +432,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCMerging(Nd4jBackend backend) {
int nArrays = 10;
@ -477,7 +478,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCMerging2(Nd4jBackend backend) {
int nArrays = 10;
@ -523,7 +524,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testROCMultiMerging(Nd4jBackend backend) {
@ -572,7 +573,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAUCPrecisionRecall(Nd4jBackend backend) {
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
@ -620,7 +621,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocAucExact(Nd4jBackend backend) {
@ -681,20 +682,20 @@ public class ROCTest extends BaseNd4jTestWithBackends {
*/
double[] p = new double[] {0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503, 0.5955447, 0.96451452,
0.6531771, 0.74890664, 0.65356987, 0.74771481, 0.96130674, 0.0083883, 0.10644438, 0.29870371,
0.65641118, 0.80981255, 0.87217591, 0.9646476, 0.72368535, 0.64247533, 0.71745362, 0.46759901,
0.32558468, 0.43964461, 0.72968908, 0.99401459, 0.67687371, 0.79082252, 0.17091426};
0.6531771, 0.74890664, 0.65356987, 0.74771481, 0.96130674, 0.0083883, 0.10644438, 0.29870371,
0.65641118, 0.80981255, 0.87217591, 0.9646476, 0.72368535, 0.64247533, 0.71745362, 0.46759901,
0.32558468, 0.43964461, 0.72968908, 0.99401459, 0.67687371, 0.79082252, 0.17091426};
double[] l = new double[] {1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
0, 1};
0, 1};
double[] fpr_skl = new double[] {0.0, 0.0, 0.15789474, 0.15789474, 0.31578947, 0.31578947, 0.52631579,
0.52631579, 0.68421053, 0.68421053, 0.84210526, 0.84210526, 0.89473684, 0.89473684, 1.0};
0.52631579, 0.68421053, 0.68421053, 0.84210526, 0.84210526, 0.89473684, 0.89473684, 1.0};
double[] tpr_skl = new double[] {0.0, 0.09090909, 0.09090909, 0.18181818, 0.18181818, 0.36363636, 0.36363636,
0.45454545, 0.45454545, 0.72727273, 0.72727273, 0.90909091, 0.90909091, 1.0, 1.0};
0.45454545, 0.45454545, 0.72727273, 0.72727273, 0.90909091, 0.90909091, 1.0, 1.0};
//Note the change to the last value: same TPR and FPR at 0.0083883 and 0.0 -> we add the 0.0 threshold edge case + combine with the previous one. Same result
double[] thr_skl = new double[] {1.0, 0.99401459, 0.96130674, 0.92961609, 0.79082252, 0.74771481, 0.67687371,
0.65641118, 0.64247533, 0.46759901, 0.31637555, 0.20456028, 0.18391881, 0.17091426, 0.0};
0.65641118, 0.64247533, 0.46759901, 0.31637555, 0.20456028, 0.18391881, 0.17091426, 0.0};
INDArray prob = Nd4j.create(p, new int[] {30, 1});
INDArray label = Nd4j.create(l, new int[] {30, 1});
@ -784,7 +785,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
@ -797,7 +798,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
double[] threshold = new double[101];
@ -814,15 +815,15 @@ public class ROCTest extends BaseNd4jTestWithBackends {
PrecisionRecallCurve prc = new PrecisionRecallCurve(threshold, precision, recall, null, null, null, -1);
PrecisionRecallCurve.Point[] points = new PrecisionRecallCurve.Point[] {
//Test exact:
prc.getPointAtThreshold(0.05), prc.getPointAtPrecision(0.05), prc.getPointAtRecall(1 - 0.05),
//Test exact:
prc.getPointAtThreshold(0.05), prc.getPointAtPrecision(0.05), prc.getPointAtRecall(1 - 0.05),
//Test approximate (point doesn't exist exactly). When it doesn't exist:
//Threshold: lowest threshold equal to or exceeding the specified threshold value
//Precision: lowest threshold equal to or exceeding the specified precision value
//Recall: highest threshold equal to or exceeding the specified recall value
prc.getPointAtThreshold(0.0495), prc.getPointAtPrecision(0.0495),
prc.getPointAtRecall(1 - 0.0505)};
//Test approximate (point doesn't exist exactly). When it doesn't exist:
//Threshold: lowest threshold equal to or exceeding the specified threshold value
//Precision: lowest threshold equal to or exceeding the specified precision value
//Recall: highest threshold equal to or exceeding the specified recall value
prc.getPointAtThreshold(0.0495), prc.getPointAtPrecision(0.0495),
prc.getPointAtRecall(1 - 0.0505)};
@ -834,7 +835,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
//Sanity check: values calculated from the confusion matrix should match the PR curve values
@ -843,7 +844,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
ROC r = new ROC(0, removeRedundantPts);
INDArray labels = Nd4j.getExecutioner()
.exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE,100, 1), 0.5));
.exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE,100, 1), 0.5));
INDArray probs = Nd4j.rand(100, 1);
r.eval(labels, probs);
@ -874,7 +875,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocMerge(){
Nd4j.getRandom().setSeed(12345);
@ -919,7 +920,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
assertEquals(auprc, auprcAct, 1e-6);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocMultiMerge(){
Nd4j.getRandom().setSeed(12345);
@ -931,9 +932,9 @@ public class ROCTest extends BaseNd4jTestWithBackends {
int nOut = 5;
Random r = new Random(12345);
for( int i=0; i<10; i++ ){
for( int i = 0; i < 10; i++ ){
INDArray labels = Nd4j.zeros(3, nOut);
for( int j=0; j<3; j++ ){
for( int j = 0; j < 3; j++) {
labels.putScalar(j, r.nextInt(nOut), 1.0 );
}
INDArray out = Nd4j.rand(3, nOut);
@ -956,7 +957,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
roc1.merge(roc2);
for( int i=0; i<nOut; i++ ) {
for( int i = 0; i < nOut; i++) {
double aucExp = roc.calculateAUC(i);
double auprc = roc.calculateAUCPR(i);
@ -969,9 +970,10 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testRocBinaryMerge(){
@Disabled
public void testRocBinaryMerge(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
ROCBinary roc = new ROCBinary();
@ -980,7 +982,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
int nOut = 5;
for( int i=0; i<10; i++ ){
for( int i = 0; i < 10; i++) {
INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(3, nOut),0.5));
INDArray out = Nd4j.rand(3, nOut);
out.diviColumnVector(out.sum(1));
@ -1015,7 +1017,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSegmentationBinary(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
@ -1106,7 +1108,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSegmentation(){
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case

View File

@ -50,7 +50,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
return 'c';
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEvalParameters(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
int specCols = 5;

View File

@ -152,7 +152,7 @@ public class LoneTest extends BaseNd4jTestWithBackends {
public void maskWhenMerge(Nd4jBackend backend) {
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
List<DataSet> dataSetList = new ArrayList<DataSet>();
List<DataSet> dataSetList = new ArrayList<>();
dataSetList.add(dsA);
dataSetList.add(dsB);
DataSet fullDataSet = DataSet.merge(dataSetList);
@ -175,7 +175,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
// System.out.println(b);
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
//broken at a threshold
public void testArgMax(Nd4jBackend backend) {
int max = 63;
@ -263,7 +264,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
// log.info("p50: {}; avg: {};", times.get(times.size() / 2), time);
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void checkIllegalElementOps(Nd4jBackend backend) {
assertThrows(Exception.class,() -> {
INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5);
@ -328,13 +330,13 @@ public class LoneTest extends BaseNd4jTestWithBackends {
reshaped.getDouble(i);
}
for (int j=0;j<arr.slices();j++) {
for (int k=0;k<arr.slice(j).length();k++) {
for (int k = 0; k < arr.slice(j).length(); k++) {
// log.info("\nArr: slice " + j + " element " + k + " " + arr.slice(j).getDouble(k));
arr.slice(j).getDouble(k);
}
}
for (int j=0;j<reshaped.slices();j++) {
for (int k=0;k<reshaped.slice(j).length();k++) {
for (int j = 0;j < reshaped.slices(); j++) {
for (int k = 0;k < reshaped.slice(j).length(); k++) {
// log.info("\nReshaped: slice " + j + " element " + k + " " + reshaped.slice(j).getDouble(k));
reshaped.slice(j).getDouble(k);
}

View File

@ -245,7 +245,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, false);
assertEquals(sorted[1], sorted2);
INDArray shouldIndex = Nd4j.create(new double[] {1, 1, 0, 0}, new long[] {2, 2});
assertEquals(shouldIndex, sorted[0],getFailureMessage());
assertEquals(shouldIndex, sorted[0],getFailureMessage(backend));
}
@ParameterizedTest
@ -266,7 +266,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, true);
assertEquals(sorted[1], sorted2);
INDArray shouldIndex = Nd4j.create(new double[] {0, 0, 1, 1}, new long[] {2, 2});
assertEquals(shouldIndex, sorted[0],getFailureMessage());
assertEquals(shouldIndex, sorted[0],getFailureMessage(backend));
}
@ParameterizedTest
@ -328,13 +328,13 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
public void testDivide(Nd4jBackend backend) {
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
INDArray div = two.div(two);
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage());
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage(backend));
INDArray half = Nd4j.create(new float[] {0.5f, 0.5f, 0.5f, 0.5f}, new long[] {2, 2});
INDArray divi = Nd4j.create(new float[] {0.3f, 0.6f, 0.9f, 0.1f}, new long[] {2, 2});
INDArray assertion = Nd4j.create(new float[] {1.6666666f, 0.8333333f, 0.5555556f, 5}, new long[] {2, 2});
INDArray result = half.div(divi);
assertEquals( assertion, result,getFailureMessage());
assertEquals( assertion, result,getFailureMessage(backend));
}
@ -344,7 +344,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
INDArray sigmoid = Transforms.sigmoid(n, false);
assertEquals( assertion, sigmoid,getFailureMessage());
assertEquals( assertion, sigmoid,getFailureMessage(backend));
}
@ -354,7 +354,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
INDArray neg = Transforms.neg(n);
assertEquals(assertion, neg,getFailureMessage());
assertEquals(assertion, neg,getFailureMessage(backend));
}
@ -365,12 +365,12 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
double sim = Transforms.cosineSim(vec1, vec2);
assertEquals(1, sim, 1e-1,getFailureMessage());
assertEquals(1, sim, 1e-1,getFailureMessage(backend));
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
sim = Transforms.cosineSim(vec3, vec4);
assertEquals(0.98, sim, 1e-1,getFailureMessage());
assertEquals(0.98, sim, 1e-1,getFailureMessage(backend));
}
@ -621,7 +621,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray innerProduct = n.mmul(transposed);
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
assertEquals(scalar, innerProduct,getFailureMessage());
assertEquals(scalar, innerProduct,getFailureMessage(backend));
}
@ -678,7 +678,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray five = Nd4j.ones(5);
five.addi(five.dup());
INDArray twos = Nd4j.valueArrayOf(5, 2);
assertEquals(twos, five,getFailureMessage());
assertEquals(twos, five,getFailureMessage(backend));
}
@ -692,7 +692,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
INDArray test = arr.mmul(arr.transpose());
assertEquals(assertion, test,getFailureMessage());
assertEquals(assertion, test,getFailureMessage(backend));
}
@ -704,7 +704,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
Nd4j.exec(new PrintVariable(newSlice));
log.info("Slice: {}", newSlice);
n.putSlice(0, newSlice);
assertEquals( newSlice, n.slice(0),getFailureMessage());
assertEquals( newSlice, n.slice(0),getFailureMessage(backend));
}
@ -713,7 +713,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
linear.putScalar(new long[] {0, 1}, 1);
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage());
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage(backend));
}
@ -1059,7 +1059,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
INDArray nClone = n1.add(n2);
assertEquals(Nd4j.scalar(3), nClone);
INDArray n1PlusN2 = n1.add(n2);
assertFalse(n1PlusN2.equals(n1),getFailureMessage());
assertFalse(n1PlusN2.equals(n1),getFailureMessage(backend));
INDArray n3 = Nd4j.scalar(3.0);
INDArray n4 = Nd4j.scalar(4.0);

View File

@ -156,7 +156,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
DataType initialType = Nd4j.dataType();
Level1 l1 = Nd4j.getBlasWrapper().level1();
@TempDir Path testDir;
@Override
public long getTimeoutMilliseconds() {
@ -255,7 +255,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSerialization(@TempDir Path testDir) throws Exception {
public void testSerialization(Nd4jBackend backend) throws Exception {
Nd4j.getRandom().setSeed(12345);
INDArray arr = Nd4j.rand(1, 20);
@ -339,7 +339,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertArrayEquals(assertion,shapeTest);
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled //temporary till libnd4j implements general broadcasting
public void testAutoBroadcastAdd(Nd4jBackend backend) {
INDArray left = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,1,2,1);
@ -370,9 +371,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
n.divi(Nd4j.scalar(1.0d));
n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
assertEquals(27, n.sumNumber().doubleValue(), 1e-1,getFailureMessage());
assertEquals(27, n.sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
INDArray a = n.slice(2);
assertEquals( true, Arrays.equals(new long[] {3, 3}, a.shape()),getFailureMessage());
assertEquals( true, Arrays.equals(new long[] {3, 3}, a.shape()),getFailureMessage(backend));
}
@ -478,12 +479,13 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
INDArray test = arr.mmul(arr.transpose());
assertEquals(assertion, test,getFailureMessage());
assertEquals(assertion, test,getFailureMessage(backend));
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testMmulOp() throws Exception {
public void testMmulOp(Nd4jBackend backend) throws Exception {
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
INDArray z = Nd4j.create(2, 2);
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
@ -494,7 +496,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
DynamicCustomOp op = new Mmul(arr, arr, z, mMulTranspose);
Nd4j.getExecutioner().execAndReturn(op);
assertEquals(assertion, z,getFailureMessage());
assertEquals(assertion, z,getFailureMessage(backend));
}
@ -505,7 +507,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray row1 = oneThroughFour.getRow(1).dup();
oneThroughFour.subiRowVector(row1);
INDArray result = Nd4j.create(new double[] {-2, -2, 0, 0}, new long[] {2, 2});
assertEquals(result, oneThroughFour,getFailureMessage());
assertEquals(result, oneThroughFour,getFailureMessage(backend));
}
@ -1093,7 +1095,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertTrue(expAllOnes.all());
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testSumAlongDim1sEdgeCases(Nd4jBackend backend) {
val shapes = new long[][] {
@ -1227,7 +1230,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray row1 = oneThroughFour.getRow(1);
row1.addi(1);
INDArray result = Nd4j.create(new double[] {1, 2, 4, 5}, new long[] {2, 2});
assertEquals(result, oneThroughFour,getFailureMessage());
assertEquals(result, oneThroughFour,getFailureMessage(backend));
}
@ -1241,8 +1244,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray linear = test.reshape(-1);
linear.putScalar(2, 6);
linear.putScalar(3, 7);
assertEquals(6, linear.getFloat(2), 1e-1,getFailureMessage());
assertEquals(7, linear.getFloat(3), 1e-1,getFailureMessage());
assertEquals(6, linear.getFloat(2), 1e-1,getFailureMessage(backend));
assertEquals(7, linear.getFloat(3), 1e-1,getFailureMessage(backend));
}
@ -1609,7 +1612,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
INDArray neg = Transforms.neg(n);
assertEquals(assertion, neg,getFailureMessage());
assertEquals(assertion, neg,getFailureMessage(backend));
}
@ -1622,13 +1625,13 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
double assertion = 5.47722557505;
double norm3 = n.norm2Number().doubleValue();
assertEquals(assertion, norm3, 1e-1,getFailureMessage());
assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray row1 = row.getRow(1);
double norm2 = row1.norm2Number().doubleValue();
double assertion2 = 5.0f;
assertEquals(assertion2, norm2, 1e-1,getFailureMessage());
assertEquals(assertion2, norm2, 1e-1,getFailureMessage(backend));
Nd4j.setDataType(initialType);
}
@ -1640,14 +1643,14 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
float assertion = 5.47722557505f;
float norm3 = n.norm2Number().floatValue();
assertEquals(assertion, norm3, 1e-1,getFailureMessage());
assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
INDArray row = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray row1 = row.getRow(1);
float norm2 = row1.norm2Number().floatValue();
float assertion2 = 5.0f;
assertEquals(assertion2, norm2, 1e-1,getFailureMessage());
assertEquals(assertion2, norm2, 1e-1,getFailureMessage(backend));
}
@ -1659,7 +1662,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
double sim = Transforms.cosineSim(vec1, vec2);
assertEquals(1, sim, 1e-1,getFailureMessage());
assertEquals(1, sim, 1e-1,getFailureMessage(backend));
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
@ -1675,14 +1678,14 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
double assertion = 2;
INDArray answer = Nd4j.create(new double[] {2, 4, 6, 8});
INDArray scal = Nd4j.getBlasWrapper().scal(assertion, answer);
assertEquals(answer, scal,getFailureMessage());
assertEquals(answer, scal,getFailureMessage(backend));
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
INDArray row1 = row.getRow(1);
double assertion2 = 5.0;
INDArray answer2 = Nd4j.create(new double[] {15, 20});
INDArray scal2 = Nd4j.getBlasWrapper().scal(assertion2, row1);
assertEquals(answer2, scal2,getFailureMessage());
assertEquals(answer2, scal2,getFailureMessage(backend));
}
@ -2076,17 +2079,17 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray innerProduct = n.mmul(transposed);
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
assertEquals(scalar, innerProduct,getFailureMessage());
assertEquals(scalar, innerProduct,getFailureMessage(backend));
INDArray outerProduct = transposed.mmul(n);
assertEquals(true, Shape.shapeEquals(new long[] {10, 10}, outerProduct.shape()),getFailureMessage());
assertEquals(true, Shape.shapeEquals(new long[] {10, 10}, outerProduct.shape()),getFailureMessage(backend));
INDArray three = Nd4j.create(new double[] {3, 4});
INDArray test = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2});
INDArray sliceRow = test.slice(0).getRow(1);
assertEquals(three, sliceRow,getFailureMessage());
assertEquals(three, sliceRow,getFailureMessage(backend));
INDArray twoSix = Nd4j.create(new double[] {2, 6}, new long[] {2, 1});
INDArray threeTwoSix = three.mmul(twoSix);
@ -2114,7 +2117,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
INDArray k1 = n1.transpose();
INDArray testVectorVector = k1.mmul(n1);
assertEquals(vectorVector, testVectorVector,getFailureMessage());
assertEquals(vectorVector, testVectorVector,getFailureMessage(backend));
}
@ -2204,7 +2207,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(linear.getDouble(0, 1), 1, 1e-1);
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSize(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
INDArray arr = Nd4j.create(4, 5);
@ -2357,7 +2361,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testTensorDot(Nd4jBackend backend) {
INDArray oneThroughSixty = Nd4j.arange(60).reshape(3, 4, 5).castTo(DataType.DOUBLE);
@ -3051,10 +3056,10 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
public void testMeans(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray mean1 = a.mean(1);
assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage());
assertEquals(Nd4j.create(new double[] {2, 3}), a.mean(0),getFailureMessage());
assertEquals(2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1,getFailureMessage());
assertEquals(2.5, a.meanNumber().doubleValue(), 1e-1,getFailureMessage());
assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage(backend));
assertEquals(Nd4j.create(new double[] {2, 3}), a.mean(0),getFailureMessage(backend));
assertEquals(2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1,getFailureMessage(backend));
assertEquals(2.5, a.meanNumber().doubleValue(), 1e-1,getFailureMessage(backend));
}
@ -3063,9 +3068,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSums(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage());
assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage());
assertEquals(10, a.sumNumber().doubleValue(), 1e-1,getFailureMessage());
assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage(backend));
assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage(backend));
assertEquals(10, a.sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
}
@ -3438,7 +3443,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void largeInstantiation(Nd4jBackend backend) {
Nd4j.ones((1024 * 1024 * 511) + (1024 * 1024 - 1)); // Still works; this can even be called as often as I want, allowing me even to spill over on disk
@ -3487,7 +3493,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(cSum, fSum); //Expect: 4,6. Getting [4, 4] for f order
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled //not relevant anymore
public void testAssignMixedC(Nd4jBackend backend) {
int[] shape1 = {3, 2, 2, 2, 2, 2};
@ -3787,7 +3794,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, result);
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation1(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(10, 10), 2, new int[] {0, 1, 2});
@ -3795,7 +3803,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
});
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation2(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, -1, 2});
@ -3803,7 +3812,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
});
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation3(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, 1, 10});
@ -3811,7 +3821,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
});
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation4(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2, 3});
@ -3819,7 +3830,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
});
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsValidation5(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2}, 'e');
@ -4975,7 +4987,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTadReduce3_5(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
INDArray initial = Nd4j.create(5, 10);
@ -6004,7 +6017,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testLogExpSum1(Nd4jBackend backend) {
INDArray matrix = Nd4j.create(3, 3);
@ -6019,7 +6033,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testLogExpSum2(Nd4jBackend backend) {
INDArray row = Nd4j.create(new double[]{1, 2, 3});
@ -6246,7 +6261,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testReshapeFailure(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
val a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2);
@ -6345,7 +6361,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertArrayEquals(new long[]{3, 2}, newShape.shape());
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTranspose1(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
@ -6360,7 +6377,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTranspose2(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
val scalar = Nd4j.scalar(2.f);
@ -6375,7 +6393,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
//@Disabled
public void testMatmul_128by256(Nd4jBackend backend) {
val mA = Nd4j.create(128, 156).assign(1.0f);
@ -6647,7 +6666,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(exp1, out1);
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBadReduce3Call(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
val x = Nd4j.create(400,20);
@ -7392,8 +7412,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(ez, z);
}
@Test()
public void testBroadcastInvalid(){
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBroadcastInvalid() {
assertThrows(IllegalStateException.class,() -> {
INDArray arr1 = Nd4j.ones(3,4,1);
@ -7656,7 +7677,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(exp, array);
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScatterUpdateShortcut_f1(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
val array = Nd4j.create(DataType.FLOAT, 5, 2);
@ -8041,7 +8063,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(exp, out); //Failing here
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPullRowsFailure(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
val idxs = new int[]{0,2,3,4};
@ -8144,7 +8167,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
assertEquals(exp1, out1); //This is OK
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPutRowValidation(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
val matrix = Nd4j.create(5, 10);
@ -8155,7 +8179,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPutColumnValidation(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
val matrix = Nd4j.create(5, 10);
@ -8236,7 +8261,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScalarEq(){
public void testScalarEq(Nd4jBackend backend){
INDArray scalarRank2 = Nd4j.scalar(10.0).reshape(1,1);
INDArray scalarRank1 = Nd4j.scalar(10.0).reshape(1);
INDArray scalarRank0 = Nd4j.scalar(10.0);
@ -8273,7 +8298,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testType1(@TempDir Path testDir) throws IOException {
@Disabled
public void testType1(Nd4jBackend backend) throws IOException {
for (int i = 0; i < 10; ++i) {
INDArray in1 = Nd4j.rand(DataType.DOUBLE, new int[]{100, 100});
File dir = testDir.toFile();
@ -8295,7 +8321,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOnes(){
public void testOnes(Nd4jBackend backend){
INDArray arr = Nd4j.ones();
INDArray arr2 = Nd4j.ones(DataType.LONG);
assertEquals(0, arr.rank());
@ -8306,7 +8332,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testZeros(){
public void testZeros(Nd4jBackend backend){
INDArray arr = Nd4j.zeros();
INDArray arr2 = Nd4j.zeros(DataType.LONG);
assertEquals(0, arr.rank());
@ -8317,7 +8343,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testType2(@TempDir Path testDir) throws IOException {
@Disabled
public void testType2(Nd4jBackend backend) throws IOException {
for (int i = 0; i < 10; ++i) {
INDArray in1 = Nd4j.ones(DataType.UINT16);
File dir = testDir.toFile();

View File

@ -23,6 +23,7 @@ package org.nd4j.linalg;
import static org.junit.jupiter.api.Assertions.assertEquals;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
@ -58,11 +59,12 @@ public class ToStringTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testToStringScalars(){
@Disabled
public void testToStringScalars(Nd4jBackend backend){
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};
for(int dt=0; dt<5; dt++ ) {
for(int dt = 0; dt < 5; dt++) {
for (int i = 0; i < 5; i++) {
long[] shape = ArrayUtil.nTimes(i, 1L);
INDArray scalar = Nd4j.scalar(1.0f).castTo(dataTypes[dt]).reshape(shape);

View File

@ -64,7 +64,6 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
}
@Test
@Disabled
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@ -79,7 +78,6 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
}
@Test
@Disabled
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@ -100,7 +98,8 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
}
@Test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCreateNpy3(Nd4jBackend backend) throws Exception {
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
assertEquals(8, arrCreate.length());
@ -111,8 +110,9 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
assertEquals(arrCreate.data().address(), pointer.address());
}
@Test
@Disabled // this is endless test
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testEndlessAllocation(Nd4jBackend backend) {
Nd4j.getEnvironment().setMaxSpecialMemory(1);
while (true) {
@ -121,9 +121,10 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
}
}
@Test
@Disabled("This test is designed to run in isolation. With parallel gc it makes no real sense since allocated amount changes at any time")
public void testAllocationLimits() throws Exception {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAllocationLimits(Nd4jBackend backend) throws Exception {
Nd4j.create(1);
val origDeviceLimit = Nd4j.getEnvironment().getDeviceLimit(0);

View File

@ -20,7 +20,6 @@
package org.nd4j.linalg.api;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;

View File

@ -59,7 +59,7 @@ public class Level1Test extends BaseNd4jTestWithBackends {
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray row = matrix.getRow(1);
Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row);
assertEquals(Nd4j.create(new double[] {4, 8}), row,getFailureMessage());
assertEquals(Nd4j.create(new double[] {4, 8}), row,getFailureMessage(backend));
}

View File

@ -70,8 +70,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
/**
* Testing level1 blas
*/
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBlasValidation1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
@ -89,8 +88,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
/**
* Testing level2 blas
*/
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBlasValidation2(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> {
@ -109,8 +107,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
/**
* Testing level3 blas
*/
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBlasValidation3(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {

View File

@ -88,7 +88,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
float[] d1 = new float[] {1, 2, 3, 4};
DataBuffer d = Nd4j.createBuffer(d1);
float[] d2 = d.asFloat();
assertArrayEquals( d1, d2, 1e-1f,getFailureMessage());
assertArrayEquals( d1, d2, 1e-1f,getFailureMessage(backend));
}
@ -146,7 +146,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
d.put(0, 0.0);
float[] result = new float[] {0, 2, 3, 4};
d1 = d.asFloat();
assertArrayEquals(d1, result, 1e-1f,getFailureMessage());
assertArrayEquals(d1, result, 1e-1f,getFailureMessage(backend));
}
@ -156,12 +156,12 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(0, 3);
float[] data = new float[] {1, 2, 3};
assertArrayEquals(get, data, 1e-1f,getFailureMessage());
assertArrayEquals(get, data, 1e-1f,getFailureMessage(backend));
float[] get2 = buffer.asFloat();
float[] allData = buffer.getFloatsAt(0, (int) buffer.length());
assertArrayEquals(get2, allData, 1e-1f,getFailureMessage());
assertArrayEquals(get2, allData, 1e-1f,getFailureMessage(backend));
}
@ -173,13 +173,13 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
float[] get = buffer.getFloatsAt(1, 3);
float[] data = new float[] {2, 3, 4};
assertArrayEquals(get, data, 1e-1f,getFailureMessage());
assertArrayEquals(get, data, 1e-1f,getFailureMessage(backend));
float[] allButLast = new float[] {2, 3, 4, 5};
float[] allData = buffer.getFloatsAt(1, (int) buffer.length());
assertArrayEquals(allButLast, allData, 1e-1f,getFailureMessage());
assertArrayEquals(allButLast, allData, 1e-1f,getFailureMessage(backend));
}
@ -190,7 +190,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
public void testAsBytes(Nd4jBackend backend) {
INDArray arr = Nd4j.create(5);
byte[] d = arr.data().asBytes();
assertEquals(4 * 5, d.length,getFailureMessage());
assertEquals(4 * 5, d.length,getFailureMessage(backend));
INDArray rand = Nd4j.rand(3, 3);
rand.data().asBytes();

View File

@ -20,26 +20,18 @@
package org.nd4j.linalg.api.indexing;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.IntervalIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndexAll;
import org.nd4j.linalg.indexing.NewAxis;
import org.nd4j.linalg.indexing.PointIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.indexing.*;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.common.util.ArrayUtil;
import java.util.Arrays;
import java.util.Random;
@ -56,22 +48,22 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNegativeBounds() {
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
INDArray get = arr.get(NDArrayIndex.all(),interval);
INDArray assertion = Nd4j.create(new double[][]{
{1,2,3},
{6,7,8}
});
assertEquals(assertion,get);
public void testNegativeBounds(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
INDArray get = arr.get(NDArrayIndex.all(),interval);
INDArray assertion = Nd4j.create(new double[][]{
{1,2,3},
{6,7,8}
});
assertEquals(assertion,get);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNewAxis() {
public void testNewAxis(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2);
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all());
long[] shapeAssertion = {3, 2, 1, 1, 2};
@ -79,9 +71,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void broadcastBug() {
public void broadcastBug(Nd4jBackend backend) {
INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2});
final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0));
@ -91,9 +83,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIntervalsIn3D() {
public void testIntervalsIn3D(Nd4jBackend backend) {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
INDArray rest = arr.get(interval(1, 2), interval(0, 2), interval(0, 2));
@ -101,9 +93,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSmallInterval() {
public void testSmallInterval(Nd4jBackend backend) {
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
INDArray rest = arr.get(interval(1, 2), all(), all());
@ -111,9 +103,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAllWithNewAxisAndInterval() {
public void testAllWithNewAxisAndInterval(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3);
@ -121,9 +113,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion2, get2);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAllWithNewAxisInMiddle() {
public void testAllWithNewAxisInMiddle(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3);
@ -131,20 +123,20 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion2, get2);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAllWithNewAxis() {
public void testAllWithNewAxis(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray get = arr.get(newAxis(), all(), point(1));
INDArray assertion = Nd4j.create(new double[][] {{4, 5, 6}, {10, 11, 12}, {16, 17, 18}, {22, 23, 24}})
.reshape(1, 4, 3);
.reshape(1, 4, 3);
assertEquals(assertion, get);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIndexingWithMmul() {
public void testIndexingWithMmul(Nd4jBackend backend) {
INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
// System.out.println(b);
@ -154,9 +146,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, c);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPointPointInterval() {
public void testPointPointInterval(Nd4jBackend backend) {
INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3);
INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3));
INDArray assertion = Nd4j.create(new double[][] {{5, 6}, {8, 9}});
@ -164,9 +156,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, get);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIntervalLowerBound() {
public void testIntervalLowerBound(Nd4jBackend backend) {
INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
INDArray assertion = Nd4j.create(new double[][] {{7, 9}, {13, 15}});
@ -176,9 +168,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetPointRowVector() {
public void testGetPointRowVector(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1);
INDArray arr2 = arr.get(point(0), interval(0, 100));
@ -187,9 +179,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(Nd4j.linspace(1, 100, 100, DataType.DOUBLE), arr2);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSpecifiedIndexVector() {
public void testSpecifiedIndexVector(Nd4jBackend backend) {
INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2);
INDArray get = rootMatrix.get(all(), new SpecifiedIndex(0, 2));
@ -205,9 +197,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testPutRowIndexing() {
public void testPutRowIndexing(Nd4jBackend backend) {
INDArray arr = Nd4j.ones(1, 10);
INDArray row = Nd4j.create(1, 10);
@ -216,9 +208,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(arr, row);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testVectorIndexing2() {
public void testVectorIndexing2(Nd4jBackend backend) {
INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true));
INDArray assertion = Nd4j.create(new double[] {2, 4});
assertEquals(assertion, wholeVector);
@ -232,9 +224,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testOffsetsC() {
public void testOffsetsC(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
assertEquals(3, NDArrayIndex.offset(arr, 1, 1));
assertEquals(3, NDArrayIndex.offset(arr, point(1), point(1)));
@ -249,9 +241,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIndexFor() {
public void testIndexFor(Nd4jBackend backend) {
long[] shape = {1, 2};
INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape);
for (int i = 0; i < indexes.length; i++) {
@ -259,9 +251,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetScalar() {
public void testGetScalar(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
INDArray d = arr.get(point(1));
assertTrue(d.isScalar());
@ -269,26 +261,26 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testVectorIndexing() {
public void testVectorIndexing(Nd4jBackend backend) {
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1);
INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
INDArray viewTest = arr.get(point(0), interval(1, 5));
assertEquals(assertion, viewTest);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNegativeIndices() {
public void testNegativeIndices(Nd4jBackend backend) {
INDArray test = Nd4j.create(10, 10, 10);
test.putScalar(new int[] {0, 0, -1}, 1.0);
assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber());
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetIndices2d() {
public void testGetIndices2d(Nd4jBackend backend) {
INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2);
INDArray firstRow = twoByTwo.getRow(0);
INDArray secondRow = twoByTwo.getRow(1);
@ -305,9 +297,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(Nd4j.create(new double[] {4}, new int[]{1,1}), individualElement);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetRow() {
public void testGetRow(Nd4jBackend backend) {
Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5);
int[] toGet = {0, 1};
@ -323,9 +315,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetRowEdgeCase() {
public void testGetRowEdgeCase(Nd4jBackend backend) {
INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
INDArray get = rowVec.getRow(0); //Returning shape [1,1]
@ -333,9 +325,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(rowVec, get);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetColumnEdgeCase() {
public void testGetColumnEdgeCase(Nd4jBackend backend) {
INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose();
INDArray get = colVec.getColumn(0); //Returning shape [1,1]
@ -343,9 +335,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(colVec, get);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testConcatColumns() {
public void testConcatColumns(Nd4jBackend backend) {
INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE);
INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE);
INDArray concat = Nd4j.concat(1, input1, input2);
@ -353,18 +345,18 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, concat);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testGetIndicesVector() {
public void testGetIndicesVector(Nd4jBackend backend) {
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
INDArray test = Nd4j.create(new double[] {2, 3});
INDArray result = line.get(point(0), interval(1, 3));
assertEquals(test, result);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testArangeMul() {
public void testArangeMul(Nd4jBackend backend) {
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
INDArrayIndex index = interval(0, 2);
INDArray get = arange.get(index, index);
@ -374,7 +366,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
assertEquals(assertion, mul);
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIndexingThorough(){
long[] fullShape = {3,4,5,6,7};
@ -575,7 +567,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
return d;
}
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void debugging(){
long[] inShape = {3,4};

View File

@ -46,12 +46,13 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Slf4j
public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void compareAfterWrite(Nd4jBackend backend) throws Exception {
int [] ranksToCheck = new int[] {0,1,2,3,4};
for (int i = 0; i < ranksToCheck.length; i++) {
// log.info("Checking read write arrays with rank " + ranksToCheck[i]);
@ -82,7 +83,7 @@ public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testNd4jReadWriteText(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void testNd4jReadWriteText(Nd4jBackend backend) throws Exception {
File dir = testDir.toFile();
int count = 0;

View File

@ -38,11 +38,11 @@ import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays;
@Slf4j
public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void compareAfterWrite(Nd4jBackend backend) throws Exception {
int[] ranksToCheck = new int[]{0, 1, 2, 3, 4};
for (int i = 0; i < ranksToCheck.length; i++) {
log.info("Checking read write arrays with rank " + ranksToCheck[i]);

View File

@ -22,6 +22,7 @@ package org.nd4j.linalg.broadcast;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
@ -135,7 +136,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
assertEquals(e, z);
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_1(Nd4jBackend backend) {
@ -146,7 +146,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
});
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_2(Nd4jBackend backend) {
@ -158,7 +157,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_3(Nd4jBackend backend) {
@ -170,16 +168,15 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void basicBroadcastFailureTest_4(Nd4jBackend backend) {
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
val z = x.addi(y);
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_5(Nd4jBackend backend) {
@ -191,7 +188,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void basicBroadcastFailureTest_6(Nd4jBackend backend) {
@ -249,9 +245,9 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
assertEquals(y, z);
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void emptyBroadcastTest_2(Nd4jBackend backend) {
val x = Nd4j.create(DataType.FLOAT, 1, 2);
val y = Nd4j.create(DataType.FLOAT, 0, 2);

View File

@ -37,7 +37,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class CompressionMagicTests extends BaseNd4jTestWithBackends {
@BeforeEach
public void setUp(Nd4jBackend backend) {
public void setUp() {
}

View File

@ -48,6 +48,7 @@ import java.util.Set;
public class DeconvTests extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@Override
public char ordering() {
@ -56,7 +57,7 @@ public class DeconvTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void compareKeras(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void compareKeras(Nd4jBackend backend) throws Exception {
File newFolder = testDir.toFile();
new ClassPathResource("keras/deconv/").copyDirectory(newFolder);

View File

@ -99,7 +99,8 @@ public class SpecialTests extends BaseNd4jTestWithBackends {
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testScalarShuffle1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
List<DataSet> listData = new ArrayList<>();

View File

@ -195,7 +195,8 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
assertEquals(exp, arrayX);
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testInplaceOp1(Nd4jBackend backend) {
assertThrows(ND4JIllegalStateException.class,() -> {
val arrayX = Nd4j.create(10, 10);

View File

@ -41,10 +41,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBalance(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void testBalance(Nd4jBackend backend) throws Exception {
DataSetIterator iterator = new IrisDataSetIterator(10, 150);
File minibatches = new File(testDir.toFile(),"mini-batch-dir");
@ -62,7 +63,7 @@ public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMiniBatchBalanced(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
public void testMiniBatchBalanced(Nd4jBackend backend) throws Exception {
int miniBatchSize = 100;
DataSetIterator iterator = new IrisDataSetIterator(miniBatchSize, 150);

View File

@ -51,8 +51,10 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
@Slf4j
public class DataSetTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@TempDir Path testDir;
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testViewIterator(Nd4jBackend backend) {
DataSetIterator iter = new ViewIterator(new IrisDataSetIterator(150, 150).next(), 10);
@ -106,9 +108,9 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSplitTestAndTrain (Nd4jBackend backend) {
public void testSplitTestAndTrain(Nd4jBackend backend) {
INDArray labels = FeatureUtil.toOutcomeMatrix(new int[] {0, 0, 0, 0, 0, 0, 0, 0}, 1);
DataSet data = new DataSet(Nd4j.rand(8, 1), labels);
@ -116,7 +118,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
assertEquals(train.getTrain().getLabels().length(), 6);
SplitTestAndTrain train2 = data.splitTestAndTrain(6, new Random(1));
assertEquals(train.getTrain().getFeatures(), train2.getTrain().getFeatures(),getFailureMessage());
assertEquals(train.getTrain().getFeatures(), train2.getTrain().getFeatures(),getFailureMessage(backend));
DataSet x0 = new IrisDataSetIterator(150, 150).next();
SplitTestAndTrain testAndTrain = x0.splitTestAndTrain(10);
@ -144,7 +146,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
SplitTestAndTrain testAndTrainRng = x2.splitTestAndTrain(10, rngHere);
assertArrayEquals(testAndTrainRng.getTrain().getFeatures().shape(),
testAndTrain.getTrain().getFeatures().shape());
testAndTrain.getTrain().getFeatures().shape());
assertEquals(testAndTrainRng.getTrain().getFeatures(), testAndTrain.getTrain().getFeatures());
assertEquals(testAndTrainRng.getTrain().getLabels(), testAndTrain.getTrain().getLabels());
@ -154,13 +156,13 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testLabelCounts(Nd4jBackend backend) {
DataSet x0 = new IrisDataSetIterator(150, 150).next();
assertEquals(0, x0.get(0).outcome(),getFailureMessage());
assertEquals( 0, x0.get(1).outcome(),getFailureMessage());
assertEquals(2, x0.get(149).outcome(),getFailureMessage());
assertEquals(0, x0.get(0).outcome(),getFailureMessage(backend));
assertEquals( 0, x0.get(1).outcome(),getFailureMessage(backend));
assertEquals(2, x0.get(149).outcome(),getFailureMessage(backend));
Map<Integer, Double> counts = x0.labelCounts();
assertEquals(50, counts.get(0), 1e-1,getFailureMessage());
assertEquals(50, counts.get(1), 1e-1,getFailureMessage());
assertEquals(50, counts.get(2), 1e-1,getFailureMessage());
assertEquals(50, counts.get(0), 1e-1,getFailureMessage(backend));
assertEquals(50, counts.get(1), 1e-1,getFailureMessage(backend));
assertEquals(50, counts.get(2), 1e-1,getFailureMessage(backend));
}
@ -694,14 +696,14 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
INDArray expLabels3d = Nd4j.create(3, 3, 4);
expLabels3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)},
l3d1);
l3d1);
expLabels3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
NDArrayIndex.interval(0, 3)}, l3d2);
NDArrayIndex.interval(0, 3)}, l3d2);
INDArray expLM3d = Nd4j.create(3, 3, 4);
expLM3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)},
lm3d1);
lm3d1);
expLM3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
NDArrayIndex.interval(0, 3)}, lm3d2);
NDArrayIndex.interval(0, 3)}, lm3d2);
DataSet merged3d = DataSet.merge(Arrays.asList(ds3d1, ds3d2));
@ -752,52 +754,52 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testShuffleNd(Nd4jBackend backend) {
int numDims = 7;
int nLabels = 3;
Random r = new Random();
int numDims = 7;
int nLabels = 3;
Random r = new Random();
int[] shape = new int[numDims];
int entries = 1;
for (int i = 0; i < numDims; i++) {
//randomly generating shapes bigger than 1
shape[i] = r.nextInt(4) + 2;
entries *= shape[i];
}
int labels = shape[0] * nLabels;
int[] shape = new int[numDims];
int entries = 1;
for (int i = 0; i < numDims; i++) {
//randomly generating shapes bigger than 1
shape[i] = r.nextInt(4) + 2;
entries *= shape[i];
}
int labels = shape[0] * nLabels;
INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape);
INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels);
INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape);
INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels);
DataSet ds = new DataSet(ds_data, ds_labels);
ds.shuffle();
DataSet ds = new DataSet(ds_data, ds_labels);
ds.shuffle();
//Checking Nd dataset which is the data
for (int dim = 1; dim < numDims; dim++) {
//get tensor along dimension - the order in every dimension but zero should be preserved
for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) {
//the difference between consecutive elements should be equal to the stride
for (int i = 0, j = 1; j < shape[dim]; i++, j++) {
int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i);
int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j);
int f_element_diff = f_next_element - f_element;
assertEquals(f_element_diff, ds_data.stride(dim));
}
}
}
//Checking 2d, features
int dim = 1;
//Checking Nd dataset which is the data
for (int dim = 1; dim < numDims; dim++) {
//get tensor along dimension - the order in every dimension but zero should be preserved
for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) {
for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) {
//the difference between consecutive elements should be equal to the stride
for (int i = 0, j = 1; j < nLabels; i++, j++) {
int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i);
int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j);
int l_element_diff = l_next_element - l_element;
assertEquals(l_element_diff, ds_labels.stride(dim));
for (int i = 0, j = 1; j < shape[dim]; i++, j++) {
int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i);
int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j);
int f_element_diff = f_next_element - f_element;
assertEquals(f_element_diff, ds_data.stride(dim));
}
}
}
//Checking 2d, features
int dim = 1;
//get tensor along dimension - the order in every dimension but zero should be preserved
for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) {
//the difference between consecutive elements should be equal to the stride
for (int i = 0, j = 1; j < nLabels; i++, j++) {
int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i);
int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j);
int l_element_diff = l_next_element - l_element;
assertEquals(l_element_diff, ds_labels.stride(dim));
}
}
}
@ParameterizedTest
@ -936,9 +938,9 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
//Checking if the features and labels are equal
assertEquals(iDataSet.getFeatures(),
dsList.get(i).getFeatures().get(all(), all(), interval(0, minTSLength + i)));
dsList.get(i).getFeatures().get(all(), all(), interval(0, minTSLength + i)));
assertEquals(iDataSet.getLabels(),
dsList.get(i).getLabels().get(all(), all(), interval(0, minTSLength + i)));
dsList.get(i).getLabels().get(all(), all(), interval(0, minTSLength + i)));
}
}
@ -964,8 +966,8 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
for (boolean lMask : b) {
DataSet ds = new DataSet((features ? f : null),
(labels ? (labelsSameAsFeatures ? f : l) : null), (fMask ? fm : null),
(lMask ? lm : null));
(labels ? (labelsSameAsFeatures ? f : l) : null), (fMask ? fm : null),
(lMask ? lm : null));
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
@ -1009,7 +1011,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
boolean lMask = true;
DataSet ds = new DataSet((features ? f : null), (labels ? (labelsSameAsFeatures ? f : l) : null),
(fMask ? fm : null), (lMask ? lm : null));
(fMask ? fm : null), (lMask ? lm : null));
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
@ -1098,7 +1100,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
public void testDataSetMetaDataSerialization(Nd4jBackend backend) throws IOException {
for(boolean withMeta : new boolean[]{false, true}) {
// create simple data set with meta data object
@ -1129,7 +1131,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMultiDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend nd4jBackend) throws IOException {
public void testMultiDataSetMetaDataSerialization(Nd4jBackend nd4jBackend) throws IOException {
for(boolean withMeta : new boolean[]{false, true}) {
// create simple data set with meta data object

View File

@ -106,7 +106,8 @@ public class KFoldIteratorTest extends BaseNd4jTestWithBackends {
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void checkCornerCaseException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
DataSet allData = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1),

View File

@ -21,27 +21,25 @@
package org.nd4j.linalg.dataset;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.nio.file.Path;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTestWithBackends {
@TempDir Path testDir;
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testMiniBatches(@TempDir Path testDir) throws Exception {
public void testMiniBatches(Nd4jBackend backend) throws Exception {
DataSet load = new IrisDataSetIterator(150, 150).next();
final MiniBatchFileDataSetIterator iter = new MiniBatchFileDataSetIterator(load, 10, false, testDir.toFile());
while (iter.hasNext())

View File

@ -39,8 +39,7 @@ public class CompositeDataSetPreProcessorTest extends BaseNd4jTestWithBackends {
return 'c';
}
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_preConditionsIsNull_expect_NullPointerException(Nd4jBackend backend) {
assertThrows(NullPointerException.class,() -> {

View File

@ -41,8 +41,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
return 'c';
}
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_originalHeightIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@ -51,8 +50,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_originalWidthIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@ -61,8 +59,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_yStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@ -71,8 +68,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_xStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@ -81,8 +77,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@ -91,8 +86,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@ -101,8 +95,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
@ -111,8 +104,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
});
}
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
// Assemble

View File

@ -39,7 +39,8 @@ public class PermuteDataSetPreProcessorTest extends BaseNd4jTestWithBackends {
return 'c';
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
assertThrows(NullPointerException.class,() -> {
// Assemble

View File

@ -20,7 +20,6 @@
package org.nd4j.linalg.dataset.api.preprocessor;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.linalg.BaseNd4jTestWithBackends;
@ -39,7 +38,8 @@ public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTestWithBacke
return 'c';
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
assertThrows(NullPointerException.class,() -> {
// Assemble

View File

@ -139,7 +139,7 @@ public class Nd4jTest extends BaseNd4jTestWithBackends {
INDArray actualResult = data.mean(0);
INDArray expectedResult = Nd4j.create(new double[] {3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6.,
6., 3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6.}, new int[] {2, 4, 4});
assertEquals(expectedResult, actualResult,getFailureMessage());
assertEquals(expectedResult, actualResult,getFailureMessage(backend));
}
@ -154,7 +154,7 @@ public class Nd4jTest extends BaseNd4jTestWithBackends {
INDArray actualResult = data.var(false, 0);
INDArray expectedResult = Nd4j.create(new double[] {1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4.,
4., 1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4.}, new long[] {2, 4, 4});
assertEquals(expectedResult, actualResult,getFailureMessage());
assertEquals(expectedResult, actualResult,getFailureMessage(backend));
}
@ParameterizedTest

View File

@ -83,8 +83,7 @@ public class CloseableTests extends BaseNd4jTestWithBackends {
}
}
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAccessException_1(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {
@ -96,8 +95,7 @@ public class CloseableTests extends BaseNd4jTestWithBackends {
}
@Test()
@ParameterizedTest
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAccessException_2(Nd4jBackend backend) {
assertThrows(IllegalStateException.class,() -> {

View File

@ -384,7 +384,9 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
assertEquals(exp, arrayZ);
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTypesValidation_1(Nd4jBackend backend) {
assertThrows(IllegalArgumentException.class,() -> {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.LONG);
@ -397,7 +399,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTypesValidation_2(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
@ -412,7 +415,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
}
@Test()
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTypesValidation_3(Nd4jBackend backend) {
assertThrows(RuntimeException.class,() -> {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
@ -422,6 +426,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
}
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testTypesValidation_4(Nd4jBackend backend) {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.DOUBLE);
@ -485,7 +491,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testBoolFloatCast2(){
public void testBoolFloatCast2(Nd4jBackend backend){
val first = Nd4j.zeros(DataType.FLOAT, 3, 5000);
INDArray asBool = first.castTo(DataType.BOOL);
INDArray not = Transforms.not(asBool); //
@ -516,7 +522,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testAssignScalarSimple(){
public void testAssignScalarSimple(Nd4jBackend backend){
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
INDArray arr = Nd4j.scalar(dt, 10.0);
arr.assign(2.0);
@ -526,7 +532,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testSimple(){
public void testSimple(Nd4jBackend backend){
Nd4j.create(1);
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, DataType.LONG}) {
// System.out.println("----- " + dt + " -----");
@ -551,7 +557,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testWorkspaceBool(){
public void testWorkspaceBool(Nd4jBackend backend){
val conf = WorkspaceConfiguration.builder().minSize(10 * 1024 * 1024)
.overallocationLimit(1.0).policyAllocation(AllocationPolicy.OVERALLOCATE)
.policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL)
@ -559,7 +565,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(conf, "WS");
for( int i=0; i<10; i++ ) {
for( int i = 0; i < 10; i++ ) {
try (val workspace = (Nd4jWorkspace)ws.notifyScopeEntered() ) {
val bool = Nd4j.create(DataType.BOOL, 1, 10);
val dbl = Nd4j.create(DataType.DOUBLE, 1, 10);
@ -574,8 +580,9 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
}
}
@Test
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@Disabled
public void testArrayCreationFromPointer(Nd4jBackend backend) {
val source = Nd4j.create(new double[]{1, 2, 3, 4, 5});

View File

@ -40,13 +40,13 @@ public class NativeBlasTests extends BaseNd4jTestWithBackends {
@BeforeEach
public void setUp(Nd4jBackend backend) {
public void setUp() {
Nd4j.getExecutioner().enableDebugMode(true);
Nd4j.getExecutioner().enableVerboseMode(true);
}
@AfterEach
public void setDown(Nd4jBackend backend) {
public void setDown() {
Nd4j.getExecutioner().enableDebugMode(false);
Nd4j.getExecutioner().enableVerboseMode(false);
}

View File

@ -77,18 +77,18 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
double sim = Transforms.cosineSim(vec1, vec2);
assertEquals( 1, sim, 1e-1,getFailureMessage());
assertEquals( 1, sim, 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testCosineDistance(){
public void testCosineDistance(Nd4jBackend backend){
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3});
INDArray vec2 = Nd4j.create(new float[] {3, 5, 7});
// 1-17*sqrt(2/581)
double distance = Transforms.cosineDistance(vec1, vec2);
assertEquals(0.0025851, distance, 1e-7,getFailureMessage());
assertEquals(0.0025851, distance, 1e-7,getFailureMessage(backend));
}
@ParameterizedTest
@ -97,7 +97,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray arr = Nd4j.create(new double[] {55, 55});
INDArray arr2 = Nd4j.create(new double[] {60, 60});
double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).z().getDouble(0);
assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage());
assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@ -137,7 +137,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi();
INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6);
Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1));
assertEquals(scalarMax, postMax,getFailureMessage());
assertEquals(scalarMax, postMax,getFailureMessage(backend));
}
@ParameterizedTest
@ -147,14 +147,14 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1));
for (int i = 0; i < linspace.length(); i++) {
double val = linspace.getDouble(i);
assertTrue( val >= 0 && val <= 1,getFailureMessage());
assertTrue( val >= 0 && val <= 1,getFailureMessage(backend));
}
INDArray linspace2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
Nd4j.getExecutioner().exec(new SetRange(linspace2, 2, 4));
for (int i = 0; i < linspace2.length(); i++) {
double val = linspace2.getDouble(i);
assertTrue( val >= 2 && val <= 4,getFailureMessage());
assertTrue( val >= 2 && val <= 4,getFailureMessage(backend));
}
}
@ -163,7 +163,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
public void testNormMax(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).z().getDouble(0);
assertEquals(4, normMax, 1e-1,getFailureMessage());
assertEquals(4, normMax, 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@ -187,7 +187,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
public void testNorm2(Nd4jBackend backend) {
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).z().getDouble(0);
assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage());
assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@ -198,7 +198,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray xDup = x.dup();
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
assertEquals(solution, x,getFailureMessage());
assertEquals(solution, x,getFailureMessage(backend));
}
@ParameterizedTest
@ -221,13 +221,13 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray xDup = x.dup();
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
assertEquals(solution, x,getFailureMessage());
assertEquals(solution, x,getFailureMessage(backend));
Sum acc = new Sum(x.dup());
opExecutioner.exec(acc);
assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
Prod prod = new Prod(x.dup());
opExecutioner.exec(prod);
assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
}
@ -275,7 +275,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Variance variance = new Variance(x.dup(), true);
opExecutioner.exec(variance);
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
}
@ -284,14 +284,14 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIamax(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage());
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage(backend));
}
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
public void testIamax2(Nd4jBackend backend) {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage());
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage(backend));
val op = new ArgAmax(linspace);
int iamax = Nd4j.getExecutioner().exec(op)[0].getInt(0);
@ -307,11 +307,11 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Mean mean = new Mean(x);
opExecutioner.exec(mean);
assertEquals( 3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
assertEquals( 3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
Variance variance = new Variance(x.dup(), true);
opExecutioner.exec(variance);
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
}
@ParameterizedTest
@ -321,7 +321,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
val softMax = new SoftMax(arr);
opExecutioner.exec((CustomOp) softMax);
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage());
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
}
@ -332,7 +332,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Pow pow = new Pow(oneThroughSix, 2);
Nd4j.getExecutioner().exec(pow);
INDArray answer = Nd4j.create(new double[] {1, 4, 9, 16, 25, 36});
assertEquals(answer, pow.z(),getFailureMessage());
assertEquals(answer, pow.z(),getFailureMessage(backend));
}
@ -384,7 +384,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
Log log = new Log(slice);
opExecutioner.exec(log);
INDArray assertion = Nd4j.create(new double[] {0., 1.09861229, 1.60943791});
assertEquals(assertion, slice,getFailureMessage());
assertEquals(assertion, slice,getFailureMessage(backend));
}
@ParameterizedTest
@ -572,7 +572,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
expected[i] = (float) Math.exp(slice.getDouble(i));
Exp exp = new Exp(slice);
opExecutioner.exec(exp);
assertEquals( Nd4j.create(expected), slice,getFailureMessage());
assertEquals( Nd4j.create(expected), slice,getFailureMessage(backend));
}
@ParameterizedTest
@ -582,7 +582,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
val softMax = new SoftMax(arr);
opExecutioner.exec((CustomOp) softMax);
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage());
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
}
@ParameterizedTest

Some files were not shown because too many files have changed in this diff Show More