Unify nd4j test profiles, get rid of old modules, fix more parameter issues with junit 5 tests
parent
e0077c38a9
commit
ad4f47096c
|
@ -31,7 +31,7 @@ jobs:
|
|||
protoc --version
|
||||
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
windows-x86_64:
|
||||
runs-on: windows-2019
|
||||
|
@ -44,7 +44,7 @@ jobs:
|
|||
run: |
|
||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
|
||||
|
||||
|
@ -60,5 +60,5 @@ jobs:
|
|||
run: |
|
||||
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ jobs:
|
|||
protoc --version
|
||||
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
windows-x86_64:
|
||||
runs-on: windows-2019
|
||||
|
@ -44,7 +44,7 @@ jobs:
|
|||
run: |
|
||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
|
||||
|
||||
|
@ -60,5 +60,5 @@ jobs:
|
|||
run: |
|
||||
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||
|
||||
|
|
|
@ -34,5 +34,5 @@ jobs:
|
|||
cmake --version
|
||||
protoc --version
|
||||
export OMP_NUM_THREADS=1
|
||||
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Ptest-nd4j-native --also-make clean test
|
||||
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Pnd4j-tests-cpu --also-make clean test
|
||||
|
||||
|
|
|
@ -109,10 +109,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -30,6 +30,8 @@ import org.datavec.api.writable.Writable;
|
|||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.loader.FileBatch;
|
||||
import java.io.File;
|
||||
|
@ -40,13 +42,16 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
@DisplayName("File Batch Record Reader Test")
|
||||
class FileBatchRecordReaderTest extends BaseND4JTest {
|
||||
public class FileBatchRecordReaderTest extends BaseND4JTest {
|
||||
@TempDir Path testDir;
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@DisplayName("Test Csv")
|
||||
void testCsv(@TempDir Path testDir) throws Exception {
|
||||
void testCsv(Nd4jBackend backend) throws Exception {
|
||||
// This is an unrealistic use case - one line/record per CSV
|
||||
File baseDir = testDir.toFile();
|
||||
List<File> fileList = new ArrayList<>();
|
||||
|
@ -75,9 +80,10 @@ class FileBatchRecordReaderTest extends BaseND4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@DisplayName("Test Csv Sequence")
|
||||
void testCsvSequence(@TempDir Path testDir) throws Exception {
|
||||
void testCsvSequence(Nd4jBackend backend) throws Exception {
|
||||
// CSV sequence - 3 lines per file, 10 files
|
||||
File baseDir = testDir.toFile();
|
||||
List<File> fileList = new ArrayList<>();
|
||||
|
|
|
@ -21,7 +21,6 @@ package org.datavec.api.transform.ops;
|
|||
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
|
|
@ -60,10 +60,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -119,10 +119,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -59,10 +59,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -57,10 +57,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -65,10 +65,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -61,25 +61,18 @@
|
|||
<artifactId>nd4j-common</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-geo</artifactId>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>python4j-numpy</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-python</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -29,7 +29,6 @@ import org.datavec.api.transform.reduce.Reducer;
|
|||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.transform.schema.SequenceSchema;
|
||||
import org.datavec.api.writable.*;
|
||||
import org.datavec.python.PythonTransform;
|
||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
@ -39,7 +38,6 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
|||
import java.util.*;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import static java.time.Duration.ofMillis;
|
||||
import static org.junit.jupiter.api.Assertions.assertTimeout;
|
||||
|
||||
|
@ -166,37 +164,8 @@ class ExecutionTest {
|
|||
List<List<Writable>> out = outRdd;
|
||||
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
|
||||
out = new ArrayList<>(out);
|
||||
Collections.sort(out, new Comparator<List<Writable>>() {
|
||||
|
||||
@Override
|
||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
||||
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
|
||||
}
|
||||
});
|
||||
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
|
||||
assertEquals(expOut, out);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
|
||||
@DisplayName("Test Python Execution Ndarray")
|
||||
void testPythonExecutionNdarray() {
|
||||
assertTimeout(ofMillis(60000), () -> {
|
||||
Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build();
|
||||
TransformProcess transformProcess = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build();
|
||||
List<List<Writable>> functions = new ArrayList<>();
|
||||
List<Writable> firstRow = new ArrayList<>();
|
||||
INDArray firstArr = Nd4j.linspace(1, 4, 4);
|
||||
INDArray secondArr = Nd4j.linspace(1, 4, 4);
|
||||
firstRow.add(new NDArrayWritable(firstArr));
|
||||
firstRow.add(new NDArrayWritable(secondArr));
|
||||
functions.add(firstRow);
|
||||
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
|
||||
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
|
||||
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
|
||||
INDArray expected = Transforms.sin(firstArr);
|
||||
INDArray secondExpected = Transforms.cos(secondArr);
|
||||
assertEquals(expected, firstResult);
|
||||
assertEquals(secondExpected, secondResult);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,155 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.local.transforms.transform;
|
||||
|
||||
import org.datavec.api.transform.ColumnType;
|
||||
import org.datavec.api.transform.Transform;
|
||||
import org.datavec.api.transform.geo.LocationType;
|
||||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform;
|
||||
import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform;
|
||||
import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform;
|
||||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
/**
|
||||
* @author saudet
|
||||
*/
|
||||
public class TestGeoTransforms {
|
||||
|
||||
@BeforeAll
|
||||
public static void beforeClass() throws Exception {
|
||||
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
|
||||
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
|
||||
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath());
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public static void afterClass(){
|
||||
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, "");
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testCoordinatesDistanceTransform() throws Exception {
|
||||
Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
|
||||
.build();
|
||||
|
||||
Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
|
||||
transform.setInputSchema(schema);
|
||||
|
||||
Schema out = transform.transform(schema);
|
||||
assertEquals(4, out.numColumns());
|
||||
assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
|
||||
assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
|
||||
out.getColumnTypes());
|
||||
|
||||
assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
|
||||
transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
|
||||
assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
|
||||
new DoubleWritable(Math.sqrt(160))),
|
||||
transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
|
||||
new Text("10|5"))));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIPAddressToCoordinatesTransform() throws Exception {
|
||||
Schema schema = new Schema.Builder().addColumnString("column").build();
|
||||
|
||||
Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER");
|
||||
transform.setInputSchema(schema);
|
||||
|
||||
Schema out = transform.transform(schema);
|
||||
|
||||
assertEquals(1, out.getColumnMetaData().size());
|
||||
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
|
||||
|
||||
String in = "81.2.69.160";
|
||||
double latitude = 51.5142;
|
||||
double longitude = -0.0931;
|
||||
|
||||
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
|
||||
assertEquals(1, writables.size());
|
||||
String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
|
||||
assertEquals(2, coordinates.length);
|
||||
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
|
||||
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
|
||||
|
||||
//Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
ObjectOutputStream oos = new ObjectOutputStream(baos);
|
||||
oos.writeObject(transform);
|
||||
|
||||
byte[] bytes = baos.toByteArray();
|
||||
|
||||
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
||||
ObjectInputStream ois = new ObjectInputStream(bais);
|
||||
|
||||
Transform deserialized = (Transform) ois.readObject();
|
||||
writables = deserialized.map(Collections.singletonList((Writable) new Text(in)));
|
||||
assertEquals(1, writables.size());
|
||||
coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
|
||||
//System.out.println(Arrays.toString(coordinates));
|
||||
assertEquals(2, coordinates.length);
|
||||
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
|
||||
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIPAddressToLocationTransform() throws Exception {
|
||||
Schema schema = new Schema.Builder().addColumnString("column").build();
|
||||
LocationType[] locationTypes = LocationType.values();
|
||||
String in = "81.2.69.160";
|
||||
String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167",
|
||||
"51.5142:-0.0931", "", "England", "6269131"}; //Note: no postcode in this test DB for this record
|
||||
|
||||
for (int i = 0; i < locationTypes.length; i++) {
|
||||
LocationType locationType = locationTypes[i];
|
||||
String location = locations[i];
|
||||
|
||||
Transform transform = new IPAddressToLocationTransform("column", locationType);
|
||||
transform.setInputSchema(schema);
|
||||
|
||||
Schema out = transform.transform(schema);
|
||||
|
||||
assertEquals(1, out.getColumnMetaData().size());
|
||||
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
|
||||
|
||||
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
|
||||
assertEquals(1, writables.size());
|
||||
assertEquals(location, writables.get(0).toString());
|
||||
//System.out.println(location);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,386 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.local.transforms.transform;
|
||||
|
||||
import org.datavec.api.transform.TransformProcess;
|
||||
import org.datavec.api.transform.condition.Condition;
|
||||
import org.datavec.api.transform.filter.ConditionFilter;
|
||||
import org.datavec.api.transform.filter.Filter;
|
||||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||
|
||||
import org.datavec.api.writable.*;
|
||||
import org.datavec.python.PythonCondition;
|
||||
import org.datavec.python.PythonTransform;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import javax.annotation.concurrent.NotThreadSafe;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
import static org.datavec.api.transform.schema.Schema.Builder;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@NotThreadSafe
|
||||
public class TestPythonTransformProcess {
|
||||
|
||||
|
||||
@Test()
|
||||
public void testStringConcat() throws Exception{
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnString("col1")
|
||||
.addColumnString("col2");
|
||||
|
||||
Schema initialSchema = schemaBuilder.build();
|
||||
schemaBuilder.addColumnString("col3");
|
||||
Schema finalSchema = schemaBuilder.build();
|
||||
|
||||
String pythonCode = "col3 = col1 + col2";
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build()
|
||||
).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
|
||||
|
||||
List<Writable> outputs = tp.execute(inputs);
|
||||
assertEquals((outputs.get(0)).toString(), "Hello ");
|
||||
assertEquals((outputs.get(1)).toString(), "World!");
|
||||
assertEquals((outputs.get(2)).toString(), "Hello World!");
|
||||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@Timeout(60000L)
|
||||
public void testMixedTypes() throws Exception {
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnInteger("col1")
|
||||
.addColumnFloat("col2")
|
||||
.addColumnString("col3")
|
||||
.addColumnDouble("col4");
|
||||
|
||||
|
||||
Schema initialSchema = schemaBuilder.build();
|
||||
schemaBuilder.addColumnInteger("col5");
|
||||
Schema finalSchema = schemaBuilder.build();
|
||||
|
||||
String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.inputSchema(initialSchema)
|
||||
.build() ).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList(new IntWritable(10),
|
||||
new FloatWritable(3.5f),
|
||||
new Text("5"),
|
||||
new DoubleWritable(2.0)
|
||||
);
|
||||
|
||||
List<Writable> outputs = tp.execute(inputs);
|
||||
assertEquals(((LongWritable)outputs.get(4)).get(), 36);
|
||||
}
|
||||
|
||||
@Test()
|
||||
@Timeout(60000L)
|
||||
public void testNDArray() throws Exception {
|
||||
long[] shape = new long[]{3, 2};
|
||||
INDArray arr1 = Nd4j.rand(shape);
|
||||
INDArray arr2 = Nd4j.rand(shape);
|
||||
|
||||
INDArray expectedOutput = arr1.add(arr2);
|
||||
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnNDArray("col1", shape)
|
||||
.addColumnNDArray("col2", shape);
|
||||
|
||||
Schema initialSchema = schemaBuilder.build();
|
||||
schemaBuilder.addColumnNDArray("col3", shape);
|
||||
Schema finalSchema = schemaBuilder.build();
|
||||
|
||||
String pythonCode = "col3 = col1 + col2";
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build() ).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList(
|
||||
(Writable)
|
||||
new NDArrayWritable(arr1),
|
||||
new NDArrayWritable(arr2)
|
||||
);
|
||||
|
||||
List<Writable> outputs = tp.execute(inputs);
|
||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
||||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@Timeout(60000L)
|
||||
public void testNDArray2() throws Exception {
|
||||
long[] shape = new long[]{3, 2};
|
||||
INDArray arr1 = Nd4j.rand(shape);
|
||||
INDArray arr2 = Nd4j.rand(shape);
|
||||
|
||||
INDArray expectedOutput = arr1.add(arr2);
|
||||
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnNDArray("col1", shape)
|
||||
.addColumnNDArray("col2", shape);
|
||||
|
||||
Schema initialSchema = schemaBuilder.build();
|
||||
schemaBuilder.addColumnNDArray("col3", shape);
|
||||
Schema finalSchema = schemaBuilder.build();
|
||||
|
||||
String pythonCode = "col3 = col1 + col2";
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build() ).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList(
|
||||
(Writable)
|
||||
new NDArrayWritable(arr1),
|
||||
new NDArrayWritable(arr2)
|
||||
);
|
||||
|
||||
List<Writable> outputs = tp.execute(inputs);
|
||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
||||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@Timeout(60000L)
|
||||
public void testNDArrayMixed() throws Exception{
|
||||
long[] shape = new long[]{3, 2};
|
||||
INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
|
||||
INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
|
||||
INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
|
||||
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnNDArray("col1", shape)
|
||||
.addColumnNDArray("col2", shape);
|
||||
|
||||
Schema initialSchema = schemaBuilder.build();
|
||||
schemaBuilder.addColumnNDArray("col3", shape);
|
||||
Schema finalSchema = schemaBuilder.build();
|
||||
|
||||
String pythonCode = "col3 = col1 + col2";
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build()
|
||||
).build();
|
||||
|
||||
List<Writable> inputs = Arrays.asList(
|
||||
(Writable)
|
||||
new NDArrayWritable(arr1),
|
||||
new NDArrayWritable(arr2)
|
||||
);
|
||||
|
||||
List<Writable> outputs = tp.execute(inputs);
|
||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
||||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@Timeout(60000L)
|
||||
public void testPythonFilter() {
|
||||
Schema schema = new Builder().addColumnInteger("column").build();
|
||||
|
||||
Condition condition = new PythonCondition(
|
||||
"f = lambda: column < 0"
|
||||
);
|
||||
|
||||
condition.setInputSchema(schema);
|
||||
|
||||
Filter filter = new ConditionFilter(condition);
|
||||
|
||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
|
||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
|
||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
|
||||
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
|
||||
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
|
||||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@Timeout(60000L)
|
||||
public void testPythonFilterAndTransform() throws Exception {
|
||||
Builder schemaBuilder = new Builder();
|
||||
schemaBuilder
|
||||
.addColumnInteger("col1")
|
||||
.addColumnFloat("col2")
|
||||
.addColumnString("col3")
|
||||
.addColumnDouble("col4");
|
||||
|
||||
Schema initialSchema = schemaBuilder.build();
|
||||
schemaBuilder.addColumnString("col6");
|
||||
Schema finalSchema = schemaBuilder.build();
|
||||
|
||||
Condition condition = new PythonCondition(
|
||||
"f = lambda: col1 < 0 and col2 > 10.0"
|
||||
);
|
||||
|
||||
condition.setInputSchema(initialSchema);
|
||||
|
||||
Filter filter = new ConditionFilter(condition);
|
||||
|
||||
String pythonCode = "col6 = str(col1 + col2)";
|
||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
||||
PythonTransform.builder().code(pythonCode)
|
||||
.outputSchema(finalSchema)
|
||||
.build()
|
||||
).filter(
|
||||
filter
|
||||
).build();
|
||||
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(
|
||||
Arrays.asList(
|
||||
(Writable)
|
||||
new IntWritable(5),
|
||||
new FloatWritable(3.0f),
|
||||
new Text("abcd"),
|
||||
new DoubleWritable(2.1))
|
||||
);
|
||||
inputs.add(
|
||||
Arrays.asList(
|
||||
(Writable)
|
||||
new IntWritable(-3),
|
||||
new FloatWritable(3.0f),
|
||||
new Text("abcd"),
|
||||
new DoubleWritable(2.1))
|
||||
);
|
||||
inputs.add(
|
||||
Arrays.asList(
|
||||
(Writable)
|
||||
new IntWritable(5),
|
||||
new FloatWritable(11.2f),
|
||||
new Text("abcd"),
|
||||
new DoubleWritable(2.1))
|
||||
);
|
||||
|
||||
LocalTransformExecutor.execute(inputs,tp);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testPythonTransformNoOutputSpecified() throws Exception {
|
||||
PythonTransform pythonTransform = PythonTransform.builder()
|
||||
.code("a += 2; b = 'hello world'")
|
||||
.returnAllInputs(true)
|
||||
.build();
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(Arrays.asList((Writable)new IntWritable(1)));
|
||||
Schema inputSchema = new Builder()
|
||||
.addColumnInteger("a")
|
||||
.build();
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
||||
.transform(pythonTransform)
|
||||
.build();
|
||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
||||
assertEquals(3,execute.get(0).get(0).toInt());
|
||||
assertEquals("hello world",execute.get(0).get(1).toString());
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNumpyTransform() {
|
||||
PythonTransform pythonTransform = PythonTransform.builder()
|
||||
.code("a += 2; b = 'hello world'")
|
||||
.returnAllInputs(true)
|
||||
.build();
|
||||
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
|
||||
Schema inputSchema = new Builder()
|
||||
.addColumnNDArray("a",new long[]{1,1})
|
||||
.build();
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
||||
.transform(pythonTransform)
|
||||
.build();
|
||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
||||
assertFalse(execute.isEmpty());
|
||||
assertNotNull(execute.get(0));
|
||||
assertNotNull(execute.get(0).get(0));
|
||||
assertNotNull(execute.get(0).get(1));
|
||||
assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get());
|
||||
assertEquals("hello world",execute.get(0).get(1).toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWithSetupRun() throws Exception {
|
||||
|
||||
PythonTransform pythonTransform = PythonTransform.builder()
|
||||
.code("five=None\n" +
|
||||
"def setup():\n" +
|
||||
" global five\n"+
|
||||
" five = 5\n\n" +
|
||||
"def run(a, b):\n" +
|
||||
" c = a + b + five\n"+
|
||||
" return {'c':c}\n\n")
|
||||
.returnAllInputs(true)
|
||||
.setupAndRun(true)
|
||||
.build();
|
||||
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)),
|
||||
new NDArrayWritable(Nd4j.scalar(2).reshape(1,1))));
|
||||
Schema inputSchema = new Builder()
|
||||
.addColumnNDArray("a",new long[]{1,1})
|
||||
.addColumnNDArray("b", new long[]{1, 1})
|
||||
.build();
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
||||
.transform(pythonTransform)
|
||||
.build();
|
||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
||||
assertFalse(execute.isEmpty());
|
||||
assertNotNull(execute.get(0));
|
||||
assertNotNull(execute.get(0).get(0));
|
||||
assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get());
|
||||
}
|
||||
|
||||
}
|
|
@ -128,10 +128,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -92,6 +92,10 @@
|
|||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-api</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-params</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.junit.vintage</groupId>
|
||||
<artifactId>junit-vintage-engine</artifactId>
|
||||
|
@ -154,7 +158,7 @@
|
|||
<skip>${skipTestResourceEnforcement}</skip>
|
||||
<rules>
|
||||
<requireActiveProfile>
|
||||
<profiles>test-nd4j-native,test-nd4j-cuda-11.0</profiles>
|
||||
<profiles>nd4j-tests-cpu,nd4j-tests-cuda</profiles>
|
||||
<all>false</all>
|
||||
</requireActiveProfile>
|
||||
</rules>
|
||||
|
@ -163,23 +167,6 @@
|
|||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<configuration>
|
||||
<argLine></argLine>
|
||||
|
||||
<!--
|
||||
By default: Surefire will set the classpath based on the manifest. Because tests are not included
|
||||
in the JAR, any tests that rely on class path scanning for resources in the tests directory will not
|
||||
function correctly without this configuration.
|
||||
For example, tests for custom transforms (where the custom transform is defined in the test directory)
|
||||
will fail due to the custom transform not being found on the classpath.
|
||||
http://maven.apache.org/surefire/maven-surefire-plugin/examples/class-loading.html
|
||||
-->
|
||||
<useSystemClassLoader>true</useSystemClassLoader>
|
||||
<useManifestOnlyJar>false</useManifestOnlyJar>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.eclipse.m2e</groupId>
|
||||
<artifactId>lifecycle-mapping</artifactId>
|
||||
|
@ -249,7 +236,7 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
@ -266,7 +253,7 @@
|
|||
</dependencies>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
@ -286,9 +273,6 @@
|
|||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<configuration>
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
|
|
|
@ -64,7 +64,7 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
@ -75,7 +75,7 @@
|
|||
</dependencies>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
|
|
@ -56,10 +56,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -166,7 +166,7 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
@ -177,7 +177,7 @@
|
|||
</dependencies>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
|
|
@ -23,7 +23,6 @@ import org.deeplearning4j.BaseDL4JTest;
|
|||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
|
||||
|
@ -34,7 +33,6 @@ import java.util.List;
|
|||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@DisplayName("Early Termination Data Set Iterator Test")
|
||||
class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
|
||||
|
|
|
@ -21,19 +21,16 @@ package org.deeplearning4j.datasets.iterator;
|
|||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@DisplayName("Early Termination Multi Data Set Iterator Test")
|
||||
|
|
|
@ -34,7 +34,6 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
|||
import org.junit.jupiter.api.Disabled;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -46,7 +45,6 @@ import java.util.Random;
|
|||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@Disabled
|
||||
@DisplayName("Attention Layer Test")
|
||||
|
|
|
@ -105,11 +105,12 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<inherited>true</inherited>
|
||||
<configuration>
|
||||
<skip>true</skip>
|
||||
</configuration>
|
||||
|
@ -118,7 +119,7 @@
|
|||
</build>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -56,10 +56,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -50,10 +50,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -45,10 +45,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -54,10 +54,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -112,7 +112,7 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
@ -123,7 +123,7 @@
|
|||
</dependencies>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
|
|
@ -72,10 +72,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -306,7 +306,7 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
@ -317,7 +317,7 @@
|
|||
</dependencies>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
|
|
@ -101,10 +101,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -49,10 +49,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -127,10 +127,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -102,7 +102,7 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
@ -113,7 +113,7 @@
|
|||
</dependencies>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
|
|
@ -99,10 +99,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -44,10 +44,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -89,10 +89,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -88,10 +88,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -90,10 +90,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -105,10 +105,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -182,10 +182,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -77,10 +77,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -104,10 +104,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -141,10 +141,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -79,10 +79,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -426,10 +426,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -44,10 +44,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -87,10 +87,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -117,10 +117,10 @@
|
|||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -143,6 +143,10 @@
|
|||
</extensions>
|
||||
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-enforcer-plugin</artifactId>
|
||||
|
@ -158,7 +162,7 @@
|
|||
<skip>${skipBackendChoice}</skip>
|
||||
<rules>
|
||||
<requireActiveProfile>
|
||||
<profiles>test-nd4j-native,test-nd4j-cuda-11.0</profiles>
|
||||
<profiles>nd4j-tests-cpu,nd4j-tests-cuda</profiles>
|
||||
<all>false</all>
|
||||
</requireActiveProfile>
|
||||
</rules>
|
||||
|
@ -227,43 +231,6 @@
|
|||
</plugin>
|
||||
</plugins>
|
||||
|
||||
<pluginManagement>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<inherited>true</inherited>
|
||||
<configuration>
|
||||
<!--
|
||||
By default: Surefire will set the classpath based on the manifest. Because tests are not included
|
||||
in the JAR, any tests that rely on class path scanning for resources in the tests directory will not
|
||||
function correctly without this configuration.
|
||||
For example, tests for custom layers (where the custom layer is defined in the test directory)
|
||||
will fail due to the custom layer not being found on the classpath.
|
||||
http://maven.apache.org/surefire/maven-surefire-plugin/examples/class-loading.html
|
||||
-->
|
||||
<useSystemClassLoader>true</useSystemClassLoader>
|
||||
<useManifestOnlyJar>false</useManifestOnlyJar>
|
||||
<argLine> -Dfile.encoding=UTF-8 -Xmx8g "</argLine>
|
||||
<includes>
|
||||
<!-- Default setting only runs tests that start/end with "Test" -->
|
||||
<include>*.java</include>
|
||||
<include>**/*.java</include>
|
||||
</includes>
|
||||
</configuration>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.apache.maven.surefire</groupId>
|
||||
<artifactId>surefire-junit-platform</artifactId>
|
||||
<version>${maven-surefire-plugin.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.eclipse.m2e</groupId>
|
||||
<artifactId>lifecycle-mapping</artifactId>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</pluginManagement>
|
||||
</build>
|
||||
|
||||
<profiles>
|
||||
|
@ -290,10 +257,10 @@
|
|||
<module>deeplearning4j-cuda</module>
|
||||
</modules>
|
||||
</profile>
|
||||
<!-- For running unit tests with nd4j-native: "mvn clean test -P test-nd4j-native"
|
||||
<!-- For running unit tests with nd4j-native: "mvn clean test -P nd4j-tests-cpu"
|
||||
Note that this excludes DL4J-cuda -->
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
<activation>
|
||||
<activeByDefault>false</activeByDefault>
|
||||
</activation>
|
||||
|
@ -311,70 +278,10 @@
|
|||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<inherited>true</inherited>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-native</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<version>${junit.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-params</artifactId>
|
||||
<version>${junit.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.maven.surefire</groupId>
|
||||
<artifactId>surefire-junit-platform</artifactId>
|
||||
<version>${maven-surefire-plugin.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<configuration>
|
||||
<environmentVariables>
|
||||
|
||||
</environmentVariables>
|
||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
||||
<includes>
|
||||
<include>*.java</include>
|
||||
<include>**/*.java</include>
|
||||
<include>**/Test*.java</include>
|
||||
<include>**/*Test.java</include>
|
||||
<include>**/*TestCase.java</include>
|
||||
</includes>
|
||||
<junitArtifactName>org.junit.jupiter:junit-jupiter-engine</junitArtifactName>
|
||||
<systemPropertyVariables>
|
||||
<org.nd4j.linalg.defaultbackend>
|
||||
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
||||
</org.nd4j.linalg.defaultbackend>
|
||||
<org.nd4j.linalg.tests.backendstorun>
|
||||
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
||||
</org.nd4j.linalg.tests.backendstorun>
|
||||
</systemPropertyVariables>
|
||||
<!--
|
||||
Maximum heap size was set to 8g, as a minimum required value for tests run.
|
||||
Depending on a build machine, default value is not always enough.
|
||||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine></argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</profile>
|
||||
<!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" -->
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-11.0</id>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
<activation>
|
||||
<activeByDefault>false</activeByDefault>
|
||||
</activation>
|
||||
|
@ -392,43 +299,6 @@
|
|||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<!-- Default to ALL modules here, unlike nd4j-native -->
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<version>${maven-surefire-plugin.version}</version>
|
||||
<configuration>
|
||||
<environmentVariables>
|
||||
</environmentVariables>
|
||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
||||
<includes>
|
||||
<include>*.java</include>
|
||||
<include>**/*.java</include>
|
||||
<include>**/Test*.java</include>
|
||||
<include>**/*Test.java</include>
|
||||
<include>**/*TestCase.java</include>
|
||||
</includes>
|
||||
<junitArtifactName>org.junit.jupiter:junit-jupiter</junitArtifactName>
|
||||
<systemPropertyVariables>
|
||||
<org.nd4j.linalg.defaultbackend>
|
||||
org.nd4j.linalg.jcublas.JCublasBackend
|
||||
</org.nd4j.linalg.defaultbackend>
|
||||
<org.nd4j.linalg.tests.backendstorun>
|
||||
org.nd4j.linalg.jcublas.JCublasBackend
|
||||
</org.nd4j.linalg.tests.backendstorun>
|
||||
</systemPropertyVariables>
|
||||
<!--
|
||||
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
||||
Depending on a build machine, default value is not always enough.
|
||||
-->
|
||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
||||
</configuration>
|
||||
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
||||
|
|
|
@ -5,7 +5,7 @@ Linux
|
|||
[INFO] Total time: 14.610 s
|
||||
[INFO] Finished at: 2021-03-06T15:35:28+09:00
|
||||
[INFO] ------------------------------------------------------------------------
|
||||
[WARNING] The requested profile "test-nd4j-native" could not be activated because it does not exist.
|
||||
[WARNING] The requested profile "nd4j-tests-cpu" could not be activated because it does not exist.
|
||||
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
|
||||
[ERROR]
|
||||
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
|
||||
|
@ -749,7 +749,7 @@ make[1]: Leaving directory '/c/Users/agibs/Documents/GitHub/eclipse-deeplearning
|
|||
[INFO] Total time: 15.482 s
|
||||
[INFO] Finished at: 2021-03-06T15:27:35+09:00
|
||||
[INFO] ------------------------------------------------------------------------
|
||||
[WARNING] The requested profile "test-nd4j-native" could not be activated because it does not exist.
|
||||
[WARNING] The requested profile "nd4j-tests-cpu" could not be activated because it does not exist.
|
||||
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
|
||||
[ERROR]
|
||||
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
Const,in_0
|
||||
Const,while/Const
|
||||
Const,while/add/y
|
||||
Identity,in_0/read
|
||||
Enter,while/Enter
|
||||
Enter,while/Enter_1
|
||||
Merge,while/Merge
|
||||
Merge,while/Merge_1
|
||||
Less,while/Less
|
||||
LoopCond,while/LoopCond
|
||||
Switch,while/Switch
|
||||
Switch,while/Switch_1
|
||||
Identity,while/Identity
|
||||
Exit,while/Exit
|
||||
Identity,while/Identity_1
|
||||
Exit,while/Exit_1
|
||||
Add,while/add
|
||||
NextIteration,while/NextIteration_1
|
||||
NextIteration,while/NextIteration
|
|
@ -0,0 +1,16 @@
|
|||
Identity,in_0/read
|
||||
Enter,while/Enter
|
||||
Enter,while/Enter_1
|
||||
Merge,while/Merge
|
||||
Merge,while/Merge_1
|
||||
Less,while/Less
|
||||
LoopCond,while/LoopCond
|
||||
Switch,while/Switch
|
||||
Switch,while/Switch_1
|
||||
Identity,while/Identity
|
||||
Exit,while/Exit
|
||||
Identity,while/Identity_1
|
||||
Exit,while/Exit_1
|
||||
Add,while/add
|
||||
NextIteration,while/NextIteration_1
|
||||
NextIteration,while/NextIteration
|
|
@ -0,0 +1,19 @@
|
|||
in_0
|
||||
while/Const
|
||||
while/add/y
|
||||
in_0/read
|
||||
while/Enter
|
||||
while/Enter_1
|
||||
while/Merge
|
||||
while/Merge_1
|
||||
while/Less
|
||||
while/LoopCond
|
||||
while/Switch
|
||||
while/Switch_1
|
||||
while/Identity
|
||||
while/Exit
|
||||
while/Identity_1
|
||||
while/Exit_1
|
||||
while/add
|
||||
while/NextIteration_1
|
||||
while/NextIteration
|
|
@ -303,7 +303,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine>-Dfile.encoding=UTF-8 "</argLine>
|
||||
<argLine>-Dfile.encoding=UTF-8 </argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -27,6 +27,7 @@ import java.util.List;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInfo;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -482,7 +483,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testConv3d(Nd4jBackend backend) {
|
||||
public void testConv3d(Nd4jBackend backend, TestInfo testInfo) {
|
||||
//Pooling3d, Conv3D, batch norm
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -573,7 +574,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
tc.testName(msg);
|
||||
String error = OpValidation.validate(tc);
|
||||
if (error != null) {
|
||||
failed.add(name);
|
||||
failed.add(testInfo.getTestMethod().get().getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1353,7 +1354,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
assertNull(err, err);
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
int nIn = 3;
|
||||
|
@ -1382,7 +1384,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -1405,7 +1408,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
|
|
@ -22,6 +22,7 @@ package org.nd4j.autodiff.opvalidation;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
@ -664,6 +665,7 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled
|
||||
public void testMmulGradientManual(Nd4jBackend backend) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
|
|
|
@ -69,7 +69,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
@AfterEach
|
||||
public void tearDown(Nd4jBackend backend) {
|
||||
public void tearDown() {
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ import lombok.val;
|
|||
import org.apache.commons.math3.linear.LUDecomposition;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInfo;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.OpValidationSuite;
|
||||
|
@ -83,7 +84,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testConcat(Nd4jBackend backend) {
|
||||
public void testConcat(Nd4jBackend backend, TestInfo testInfo) {
|
||||
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
|
||||
int[] concatDim = new int[]{0, 0, 0};
|
||||
List<List<int[]>> origShapes = new ArrayList<>();
|
||||
|
@ -115,7 +116,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
String error = OpValidation.validate(tc);
|
||||
if(error != null){
|
||||
failed.add(name);
|
||||
failed.add(testInfo.getTestMethod().get().getName());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -285,7 +286,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testSqueezeGradient(Nd4jBackend backend) {
|
||||
public void testSqueezeGradient(Nd4jBackend backend,TestInfo testInfo) {
|
||||
val origShape = new long[]{3, 4, 5};
|
||||
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
@ -339,7 +340,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
String error = OpValidation.validate(tc, true);
|
||||
if(error != null){
|
||||
failed.add(name);
|
||||
failed.add(testInfo.getTestMethod().get().getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -580,8 +581,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
return Long.MAX_VALUE;
|
||||
}
|
||||
|
||||
@Test()
|
||||
public void testStack(Nd4jBackend backend) {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testStack(Nd4jBackend backend,TestInfo testInfo) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
@ -661,7 +663,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
String error = OpValidation.validate(tc);
|
||||
if(error != null){
|
||||
failed.add(name);
|
||||
failed.add(testInfo.getTestMethod().get().getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -72,6 +72,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
@Slf4j
|
||||
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
@TempDir Path testDir;
|
||||
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
|
@ -82,7 +84,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void testBasic(Nd4jBackend backend) throws Exception {
|
||||
SameDiff sd = SameDiff.create();
|
||||
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
|
||||
SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() );
|
||||
|
@ -121,7 +123,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
int numOutputs = fg.outputsLength();
|
||||
List<IntPair> outputs = new ArrayList<>(numOutputs);
|
||||
for( int i=0; i<numOutputs; i++ ){
|
||||
for( int i = 0; i < numOutputs; i++) {
|
||||
outputs.add(fg.outputs(i));
|
||||
}
|
||||
|
||||
|
@ -138,7 +140,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void testSimple(Nd4jBackend backend) throws Exception {
|
||||
for( int i = 0; i < 10; i++ ) {
|
||||
for(boolean execFirst : new boolean[]{false, true}) {
|
||||
log.info("Starting test: i={}, execFirst={}", i, execFirst);
|
||||
|
@ -268,7 +270,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void testTrainingSerde(Nd4jBackend backend) throws Exception {
|
||||
|
||||
//Ensure 2 things:
|
||||
//1. Training config is serialized/deserialized correctly
|
||||
|
|
|
@ -109,7 +109,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
@BeforeEach
|
||||
public void before(Nd4jBackend backend) {
|
||||
public void before() {
|
||||
Nd4j.create(1);
|
||||
initialType = Nd4j.dataType();
|
||||
|
||||
|
@ -118,7 +118,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
@AfterEach
|
||||
public void after(Nd4jBackend backend) {
|
||||
public void after() {
|
||||
Nd4j.setDataType(initialType);
|
||||
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
||||
|
|
|
@ -21,9 +21,6 @@
|
|||
package org.nd4j.autodiff.samediff.listeners;
|
||||
|
||||
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
@ -47,11 +44,11 @@ import java.util.List;
|
|||
import java.util.Set;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||
@TempDir Path testDir;
|
||||
|
||||
|
||||
@Override
|
||||
|
@ -97,7 +94,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void testCheckpointEveryEpoch(Nd4jBackend backend) throws Exception {
|
||||
File dir = testDir.toFile();
|
||||
|
||||
SameDiff sd = getModel();
|
||||
|
@ -132,7 +129,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void testCheckpointEvery5Iter(Nd4jBackend backend) throws Exception {
|
||||
File dir = testDir.toFile();
|
||||
|
||||
SameDiff sd = getModel();
|
||||
|
@ -172,7 +169,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void testCheckpointListenerEveryTimeUnit(Nd4jBackend backend) throws Exception {
|
||||
File dir = testDir.toFile();
|
||||
SameDiff sd = getModel();
|
||||
|
||||
|
@ -217,7 +214,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void testCheckpointListenerKeepLast3AndEvery3(Nd4jBackend backend) throws Exception {
|
||||
File dir = testDir.toFile();
|
||||
SameDiff sd = getModel();
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ package org.nd4j.autodiff.samediff.listeners;
|
|||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
@ -49,6 +50,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
|
|||
|
||||
public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
@TempDir Path testDir;
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -59,7 +61,8 @@ public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
@Disabled
|
||||
public void testProfilingListenerSimple(Nd4jBackend backend) throws Exception {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
|
||||
|
|
|
@ -64,6 +64,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
|||
@Slf4j
|
||||
public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
@TempDir Path testDir;
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
|
@ -81,7 +82,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
|
||||
public void testSimple(Nd4jBackend backend) throws IOException {
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
|
||||
SDVariable sum = v.sum();
|
||||
|
@ -163,7 +164,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
//Append a number of events
|
||||
w.registerEventName("accuracy");
|
||||
for( int iter=0; iter<3; iter++) {
|
||||
for( int iter = 0; iter < 3; iter++) {
|
||||
long t = System.currentTimeMillis();
|
||||
w.writeScalarEvent("accuracy", LogFileWriter.EventSubtype.EVALUATION, t, iter, 0, 0.5 + 0.1 * iter);
|
||||
}
|
||||
|
@ -175,7 +176,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
|||
UIAddName addName = (UIAddName) events.get(0).getRight();
|
||||
assertEquals("accuracy", addName.name());
|
||||
|
||||
for( int i=1; i<4; i++ ){
|
||||
for( int i = 1; i < 4; i++ ){
|
||||
FlatArray fa = (FlatArray) events.get(i).getRight();
|
||||
INDArray arr = Nd4j.createFromFlatArray(fa);
|
||||
|
||||
|
@ -186,7 +187,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{
|
||||
public void testNullBinLabels(Nd4jBackend backend) throws Exception{
|
||||
File dir = testDir.toFile();
|
||||
File f = new File(dir, "temp.bin");
|
||||
LogFileWriter w = new LogFileWriter(f);
|
||||
|
|
|
@ -56,6 +56,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
@TempDir Path testDir;
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
@ -65,7 +67,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void testUIListenerBasic(Nd4jBackend backend) throws Exception {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
|
@ -102,7 +104,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void testUIListenerContinue(Nd4jBackend backend) throws Exception {
|
||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
|
||||
SameDiff sd1 = getSimpleNet();
|
||||
|
@ -194,7 +196,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void testUIListenerBadContinue(Nd4jBackend backend) throws Exception {
|
||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
SameDiff sd1 = getSimpleNet();
|
||||
|
||||
|
@ -275,7 +277,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
private static SameDiff getSimpleNet(){
|
||||
private static SameDiff getSimpleNet() {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
|
||||
|
|
|
@ -22,6 +22,7 @@ package org.nd4j.evaluation;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
@ -48,6 +49,7 @@ public class NewInstanceTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled
|
||||
public void testNewInstances(Nd4jBackend backend) {
|
||||
boolean print = true;
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
|
||||
package org.nd4j.evaluation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.evaluation.classification.ROC;
|
||||
|
@ -42,14 +42,15 @@ import java.util.List;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled
|
||||
public void testROCBinary(Nd4jBackend backend) {
|
||||
//Compare ROCBinary to ROC class
|
||||
|
||||
|
@ -144,7 +145,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testRocBinaryMerging(Nd4jBackend backend) {
|
||||
for (int nSteps : new int[]{30, 0}) { //0 == exact
|
||||
|
@ -175,7 +176,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testROCBinaryPerOutputMasking(Nd4jBackend backend) {
|
||||
|
||||
|
@ -216,7 +217,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testROCBinary3d(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||
|
@ -251,7 +252,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testROCBinary4d(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
|
@ -286,7 +287,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testROCBinary3dMasking(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||
|
@ -348,7 +349,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testROCBinary4dMasking(Nd4jBackend backend) {
|
||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
package org.nd4j.evaluation;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
@ -82,16 +83,16 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
expFPR.put(10 / 10.0, 0.0 / totalNegatives);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testRocBasic(Nd4jBackend backend) {
|
||||
//2 outputs here - probability distribution over classes (softmax)
|
||||
INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
||||
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
||||
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
|
||||
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
||||
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
|
||||
|
||||
INDArray actual = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
|
||||
{0, 1}, {0, 1}});
|
||||
{0, 1}, {0, 1}});
|
||||
|
||||
ROC roc = new ROC(10);
|
||||
roc.eval(actual, predictions);
|
||||
|
@ -126,15 +127,15 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
assertEquals(1.0, auc, 1e-6);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testRocBasicSingleClass(Nd4jBackend backend) {
|
||||
//1 output here - single probability value (sigmoid)
|
||||
|
||||
//add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
||||
INDArray predictions =
|
||||
Nd4j.create(new double[] {0.001, 0.101, 0.201, 0.301, 0.401, 0.501, 0.601, 0.701, 0.801, 0.901},
|
||||
new int[] {10, 1});
|
||||
Nd4j.create(new double[] {0.001, 0.101, 0.201, 0.301, 0.401, 0.501, 0.601, 0.701, 0.801, 0.901},
|
||||
new int[] {10, 1});
|
||||
|
||||
INDArray actual = Nd4j.create(new double[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}, new int[] {10, 1});
|
||||
|
||||
|
@ -165,7 +166,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testRoc(Nd4jBackend backend) {
|
||||
//Previous tests allowed for a perfect classifier with right threshold...
|
||||
|
@ -173,7 +174,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}});
|
||||
|
||||
INDArray prediction = Nd4j.create(new double[][] {{0.199, 0.801}, {0.499, 0.501}, {0.399, 0.601},
|
||||
{0.799, 0.201}, {0.899, 0.101}});
|
||||
{0.799, 0.201}, {0.899, 0.101}});
|
||||
|
||||
Map<Double, Double> expTPR = new HashMap<>();
|
||||
double totalPositives = 2.0;
|
||||
|
@ -251,27 +252,27 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
|
||||
//Same as first test...
|
||||
|
||||
//2 outputs here - probability distribution over classes (softmax)
|
||||
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
||||
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
||||
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
|
||||
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
||||
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
|
||||
|
||||
INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
|
||||
{0, 1}, {0, 1}});
|
||||
{0, 1}, {0, 1}});
|
||||
|
||||
INDArray predictions3d = Nd4j.create(2, 2, 5);
|
||||
INDArray firstTSp =
|
||||
predictions3d.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
|
||||
predictions3d.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
|
||||
assertArrayEquals(new long[] {5, 2}, firstTSp.shape());
|
||||
firstTSp.assign(predictions2d.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all()));
|
||||
|
||||
INDArray secondTSp =
|
||||
predictions3d.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
|
||||
predictions3d.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
|
||||
assertArrayEquals(new long[] {5, 2}, secondTSp.shape());
|
||||
secondTSp.assign(predictions2d.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all()));
|
||||
|
||||
|
@ -299,23 +300,23 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testRocTimeSeriesMasking(Nd4jBackend backend) {
|
||||
//2 outputs here - probability distribution over classes (softmax)
|
||||
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
||||
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
||||
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
|
||||
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
||||
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
|
||||
|
||||
INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
|
||||
{0, 1}, {0, 1}});
|
||||
{0, 1}, {0, 1}});
|
||||
|
||||
|
||||
//Create time series data... first time series: length 4. Second time series: length 6
|
||||
INDArray predictions3d = Nd4j.create(2, 2, 6);
|
||||
INDArray tad = predictions3d.tensorAlongDimension(0, 1, 2).transpose();
|
||||
tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
|
||||
.assign(predictions2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
|
||||
.assign(predictions2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
|
||||
|
||||
tad = predictions3d.tensorAlongDimension(1, 1, 2).transpose();
|
||||
tad.assign(predictions2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));
|
||||
|
@ -324,7 +325,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
INDArray labels3d = Nd4j.create(2, 2, 6);
|
||||
tad = labels3d.tensorAlongDimension(0, 1, 2).transpose();
|
||||
tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
|
||||
.assign(actual2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
|
||||
.assign(actual2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
|
||||
|
||||
tad = labels3d.tensorAlongDimension(1, 1, 2).transpose();
|
||||
tad.assign(actual2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));
|
||||
|
@ -350,7 +351,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -381,7 +382,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testCompare2Vs3Classes(Nd4jBackend backend) {
|
||||
|
||||
|
@ -431,7 +432,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testROCMerging(Nd4jBackend backend) {
|
||||
int nArrays = 10;
|
||||
|
@ -477,7 +478,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testROCMerging2(Nd4jBackend backend) {
|
||||
int nArrays = 10;
|
||||
|
@ -523,7 +524,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testROCMultiMerging(Nd4jBackend backend) {
|
||||
|
||||
|
@ -572,7 +573,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testAUCPrecisionRecall(Nd4jBackend backend) {
|
||||
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
|
||||
|
@ -620,7 +621,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testRocAucExact(Nd4jBackend backend) {
|
||||
|
||||
|
@ -681,20 +682,20 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
*/
|
||||
|
||||
double[] p = new double[] {0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503, 0.5955447, 0.96451452,
|
||||
0.6531771, 0.74890664, 0.65356987, 0.74771481, 0.96130674, 0.0083883, 0.10644438, 0.29870371,
|
||||
0.65641118, 0.80981255, 0.87217591, 0.9646476, 0.72368535, 0.64247533, 0.71745362, 0.46759901,
|
||||
0.32558468, 0.43964461, 0.72968908, 0.99401459, 0.67687371, 0.79082252, 0.17091426};
|
||||
0.6531771, 0.74890664, 0.65356987, 0.74771481, 0.96130674, 0.0083883, 0.10644438, 0.29870371,
|
||||
0.65641118, 0.80981255, 0.87217591, 0.9646476, 0.72368535, 0.64247533, 0.71745362, 0.46759901,
|
||||
0.32558468, 0.43964461, 0.72968908, 0.99401459, 0.67687371, 0.79082252, 0.17091426};
|
||||
|
||||
double[] l = new double[] {1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
|
||||
0, 1};
|
||||
0, 1};
|
||||
|
||||
double[] fpr_skl = new double[] {0.0, 0.0, 0.15789474, 0.15789474, 0.31578947, 0.31578947, 0.52631579,
|
||||
0.52631579, 0.68421053, 0.68421053, 0.84210526, 0.84210526, 0.89473684, 0.89473684, 1.0};
|
||||
0.52631579, 0.68421053, 0.68421053, 0.84210526, 0.84210526, 0.89473684, 0.89473684, 1.0};
|
||||
double[] tpr_skl = new double[] {0.0, 0.09090909, 0.09090909, 0.18181818, 0.18181818, 0.36363636, 0.36363636,
|
||||
0.45454545, 0.45454545, 0.72727273, 0.72727273, 0.90909091, 0.90909091, 1.0, 1.0};
|
||||
0.45454545, 0.45454545, 0.72727273, 0.72727273, 0.90909091, 0.90909091, 1.0, 1.0};
|
||||
//Note the change to the last value: same TPR and FPR at 0.0083883 and 0.0 -> we add the 0.0 threshold edge case + combine with the previous one. Same result
|
||||
double[] thr_skl = new double[] {1.0, 0.99401459, 0.96130674, 0.92961609, 0.79082252, 0.74771481, 0.67687371,
|
||||
0.65641118, 0.64247533, 0.46759901, 0.31637555, 0.20456028, 0.18391881, 0.17091426, 0.0};
|
||||
0.65641118, 0.64247533, 0.46759901, 0.31637555, 0.20456028, 0.18391881, 0.17091426, 0.0};
|
||||
|
||||
INDArray prob = Nd4j.create(p, new int[] {30, 1});
|
||||
INDArray label = Nd4j.create(l, new int[] {30, 1});
|
||||
|
@ -784,7 +785,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
|
||||
|
||||
|
@ -797,7 +798,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
|
||||
double[] threshold = new double[101];
|
||||
|
@ -814,15 +815,15 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
PrecisionRecallCurve prc = new PrecisionRecallCurve(threshold, precision, recall, null, null, null, -1);
|
||||
|
||||
PrecisionRecallCurve.Point[] points = new PrecisionRecallCurve.Point[] {
|
||||
//Test exact:
|
||||
prc.getPointAtThreshold(0.05), prc.getPointAtPrecision(0.05), prc.getPointAtRecall(1 - 0.05),
|
||||
//Test exact:
|
||||
prc.getPointAtThreshold(0.05), prc.getPointAtPrecision(0.05), prc.getPointAtRecall(1 - 0.05),
|
||||
|
||||
//Test approximate (point doesn't exist exactly). When it doesn't exist:
|
||||
//Threshold: lowest threshold equal to or exceeding the specified threshold value
|
||||
//Precision: lowest threshold equal to or exceeding the specified precision value
|
||||
//Recall: highest threshold equal to or exceeding the specified recall value
|
||||
prc.getPointAtThreshold(0.0495), prc.getPointAtPrecision(0.0495),
|
||||
prc.getPointAtRecall(1 - 0.0505)};
|
||||
//Test approximate (point doesn't exist exactly). When it doesn't exist:
|
||||
//Threshold: lowest threshold equal to or exceeding the specified threshold value
|
||||
//Precision: lowest threshold equal to or exceeding the specified precision value
|
||||
//Recall: highest threshold equal to or exceeding the specified recall value
|
||||
prc.getPointAtThreshold(0.0495), prc.getPointAtPrecision(0.0495),
|
||||
prc.getPointAtRecall(1 - 0.0505)};
|
||||
|
||||
|
||||
|
||||
|
@ -834,7 +835,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
|
||||
//Sanity check: values calculated from the confusion matrix should match the PR curve values
|
||||
|
@ -843,7 +844,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
ROC r = new ROC(0, removeRedundantPts);
|
||||
|
||||
INDArray labels = Nd4j.getExecutioner()
|
||||
.exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE,100, 1), 0.5));
|
||||
.exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE,100, 1), 0.5));
|
||||
INDArray probs = Nd4j.rand(100, 1);
|
||||
|
||||
r.eval(labels, probs);
|
||||
|
@ -874,7 +875,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testRocMerge(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -919,7 +920,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
assertEquals(auprc, auprcAct, 1e-6);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testRocMultiMerge(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -931,9 +932,9 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
int nOut = 5;
|
||||
|
||||
Random r = new Random(12345);
|
||||
for( int i=0; i<10; i++ ){
|
||||
for( int i = 0; i < 10; i++ ){
|
||||
INDArray labels = Nd4j.zeros(3, nOut);
|
||||
for( int j=0; j<3; j++ ){
|
||||
for( int j = 0; j < 3; j++) {
|
||||
labels.putScalar(j, r.nextInt(nOut), 1.0 );
|
||||
}
|
||||
INDArray out = Nd4j.rand(3, nOut);
|
||||
|
@ -956,7 +957,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
roc1.merge(roc2);
|
||||
|
||||
for( int i=0; i<nOut; i++ ) {
|
||||
for( int i = 0; i < nOut; i++) {
|
||||
|
||||
double aucExp = roc.calculateAUC(i);
|
||||
double auprc = roc.calculateAUCPR(i);
|
||||
|
@ -969,9 +970,10 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testRocBinaryMerge(){
|
||||
@Disabled
|
||||
public void testRocBinaryMerge(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
ROCBinary roc = new ROCBinary();
|
||||
|
@ -980,7 +982,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
int nOut = 5;
|
||||
|
||||
for( int i=0; i<10; i++ ){
|
||||
for( int i = 0; i < 10; i++) {
|
||||
INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(3, nOut),0.5));
|
||||
INDArray out = Nd4j.rand(3, nOut);
|
||||
out.diviColumnVector(out.sum(1));
|
||||
|
@ -1015,7 +1017,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testSegmentationBinary(){
|
||||
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
||||
|
@ -1106,7 +1108,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testSegmentation(){
|
||||
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
||||
|
|
|
@ -50,7 +50,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
|||
return 'c';
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testEvalParameters(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
int specCols = 5;
|
||||
|
|
|
@ -152,7 +152,7 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
|||
public void maskWhenMerge(Nd4jBackend backend) {
|
||||
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
|
||||
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
|
||||
List<DataSet> dataSetList = new ArrayList<DataSet>();
|
||||
List<DataSet> dataSetList = new ArrayList<>();
|
||||
dataSetList.add(dsA);
|
||||
dataSetList.add(dsB);
|
||||
DataSet fullDataSet = DataSet.merge(dataSetList);
|
||||
|
@ -175,7 +175,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
|||
// System.out.println(b);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
//broken at a threshold
|
||||
public void testArgMax(Nd4jBackend backend) {
|
||||
int max = 63;
|
||||
|
@ -263,7 +264,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
|||
// log.info("p50: {}; avg: {};", times.get(times.size() / 2), time);
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void checkIllegalElementOps(Nd4jBackend backend) {
|
||||
assertThrows(Exception.class,() -> {
|
||||
INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5);
|
||||
|
@ -328,13 +330,13 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
|||
reshaped.getDouble(i);
|
||||
}
|
||||
for (int j=0;j<arr.slices();j++) {
|
||||
for (int k=0;k<arr.slice(j).length();k++) {
|
||||
for (int k = 0; k < arr.slice(j).length(); k++) {
|
||||
// log.info("\nArr: slice " + j + " element " + k + " " + arr.slice(j).getDouble(k));
|
||||
arr.slice(j).getDouble(k);
|
||||
}
|
||||
}
|
||||
for (int j=0;j<reshaped.slices();j++) {
|
||||
for (int k=0;k<reshaped.slice(j).length();k++) {
|
||||
for (int j = 0;j < reshaped.slices(); j++) {
|
||||
for (int k = 0;k < reshaped.slice(j).length(); k++) {
|
||||
// log.info("\nReshaped: slice " + j + " element " + k + " " + reshaped.slice(j).getDouble(k));
|
||||
reshaped.slice(j).getDouble(k);
|
||||
}
|
||||
|
|
|
@ -245,7 +245,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
|||
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, false);
|
||||
assertEquals(sorted[1], sorted2);
|
||||
INDArray shouldIndex = Nd4j.create(new double[] {1, 1, 0, 0}, new long[] {2, 2});
|
||||
assertEquals(shouldIndex, sorted[0],getFailureMessage());
|
||||
assertEquals(shouldIndex, sorted[0],getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -266,7 +266,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
|||
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, true);
|
||||
assertEquals(sorted[1], sorted2);
|
||||
INDArray shouldIndex = Nd4j.create(new double[] {0, 0, 1, 1}, new long[] {2, 2});
|
||||
assertEquals(shouldIndex, sorted[0],getFailureMessage());
|
||||
assertEquals(shouldIndex, sorted[0],getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -328,13 +328,13 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
|||
public void testDivide(Nd4jBackend backend) {
|
||||
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
|
||||
INDArray div = two.div(two);
|
||||
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage());
|
||||
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage(backend));
|
||||
|
||||
INDArray half = Nd4j.create(new float[] {0.5f, 0.5f, 0.5f, 0.5f}, new long[] {2, 2});
|
||||
INDArray divi = Nd4j.create(new float[] {0.3f, 0.6f, 0.9f, 0.1f}, new long[] {2, 2});
|
||||
INDArray assertion = Nd4j.create(new float[] {1.6666666f, 0.8333333f, 0.5555556f, 5}, new long[] {2, 2});
|
||||
INDArray result = half.div(divi);
|
||||
assertEquals( assertion, result,getFailureMessage());
|
||||
assertEquals( assertion, result,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
|
||||
|
@ -344,7 +344,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
|||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
|
||||
INDArray sigmoid = Transforms.sigmoid(n, false);
|
||||
assertEquals( assertion, sigmoid,getFailureMessage());
|
||||
assertEquals( assertion, sigmoid,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -354,7 +354,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
|||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
||||
INDArray neg = Transforms.neg(n);
|
||||
assertEquals(assertion, neg,getFailureMessage());
|
||||
assertEquals(assertion, neg,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -365,12 +365,12 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
|||
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||
double sim = Transforms.cosineSim(vec1, vec2);
|
||||
assertEquals(1, sim, 1e-1,getFailureMessage());
|
||||
assertEquals(1, sim, 1e-1,getFailureMessage(backend));
|
||||
|
||||
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
|
||||
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
|
||||
sim = Transforms.cosineSim(vec3, vec4);
|
||||
assertEquals(0.98, sim, 1e-1,getFailureMessage());
|
||||
assertEquals(0.98, sim, 1e-1,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -621,7 +621,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
|||
INDArray innerProduct = n.mmul(transposed);
|
||||
|
||||
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
|
||||
assertEquals(scalar, innerProduct,getFailureMessage());
|
||||
assertEquals(scalar, innerProduct,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
|
||||
|
@ -678,7 +678,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
|||
INDArray five = Nd4j.ones(5);
|
||||
five.addi(five.dup());
|
||||
INDArray twos = Nd4j.valueArrayOf(5, 2);
|
||||
assertEquals(twos, five,getFailureMessage());
|
||||
assertEquals(twos, five,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -692,7 +692,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
|||
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
||||
|
||||
INDArray test = arr.mmul(arr.transpose());
|
||||
assertEquals(assertion, test,getFailureMessage());
|
||||
assertEquals(assertion, test,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -704,7 +704,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
|||
Nd4j.exec(new PrintVariable(newSlice));
|
||||
log.info("Slice: {}", newSlice);
|
||||
n.putSlice(0, newSlice);
|
||||
assertEquals( newSlice, n.slice(0),getFailureMessage());
|
||||
assertEquals( newSlice, n.slice(0),getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -713,7 +713,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
|||
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
|
||||
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
|
||||
linear.putScalar(new long[] {0, 1}, 1);
|
||||
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage());
|
||||
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
|
||||
|
@ -1059,7 +1059,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
|||
INDArray nClone = n1.add(n2);
|
||||
assertEquals(Nd4j.scalar(3), nClone);
|
||||
INDArray n1PlusN2 = n1.add(n2);
|
||||
assertFalse(n1PlusN2.equals(n1),getFailureMessage());
|
||||
assertFalse(n1PlusN2.equals(n1),getFailureMessage(backend));
|
||||
|
||||
INDArray n3 = Nd4j.scalar(3.0);
|
||||
INDArray n4 = Nd4j.scalar(4.0);
|
||||
|
|
|
@ -156,7 +156,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
DataType initialType = Nd4j.dataType();
|
||||
Level1 l1 = Nd4j.getBlasWrapper().level1();
|
||||
|
||||
@TempDir Path testDir;
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
|
@ -255,7 +255,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testSerialization(@TempDir Path testDir) throws Exception {
|
||||
public void testSerialization(Nd4jBackend backend) throws Exception {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray arr = Nd4j.rand(1, 20);
|
||||
|
||||
|
@ -339,7 +339,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
assertArrayEquals(assertion,shapeTest);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled //temporary till libnd4j implements general broadcasting
|
||||
public void testAutoBroadcastAdd(Nd4jBackend backend) {
|
||||
INDArray left = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,1,2,1);
|
||||
|
@ -370,9 +371,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
n.divi(Nd4j.scalar(1.0d));
|
||||
|
||||
n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
|
||||
assertEquals(27, n.sumNumber().doubleValue(), 1e-1,getFailureMessage());
|
||||
assertEquals(27, n.sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||
INDArray a = n.slice(2);
|
||||
assertEquals( true, Arrays.equals(new long[] {3, 3}, a.shape()),getFailureMessage());
|
||||
assertEquals( true, Arrays.equals(new long[] {3, 3}, a.shape()),getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -478,12 +479,13 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
||||
|
||||
INDArray test = arr.mmul(arr.transpose());
|
||||
assertEquals(assertion, test,getFailureMessage());
|
||||
assertEquals(assertion, test,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled
|
||||
public void testMmulOp() throws Exception {
|
||||
public void testMmulOp(Nd4jBackend backend) throws Exception {
|
||||
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
|
||||
INDArray z = Nd4j.create(2, 2);
|
||||
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
||||
|
@ -494,7 +496,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
DynamicCustomOp op = new Mmul(arr, arr, z, mMulTranspose);
|
||||
Nd4j.getExecutioner().execAndReturn(op);
|
||||
|
||||
assertEquals(assertion, z,getFailureMessage());
|
||||
assertEquals(assertion, z,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
|
||||
|
@ -505,7 +507,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
INDArray row1 = oneThroughFour.getRow(1).dup();
|
||||
oneThroughFour.subiRowVector(row1);
|
||||
INDArray result = Nd4j.create(new double[] {-2, -2, 0, 0}, new long[] {2, 2});
|
||||
assertEquals(result, oneThroughFour,getFailureMessage());
|
||||
assertEquals(result, oneThroughFour,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -1093,7 +1095,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
assertTrue(expAllOnes.all());
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled
|
||||
public void testSumAlongDim1sEdgeCases(Nd4jBackend backend) {
|
||||
val shapes = new long[][] {
|
||||
|
@ -1227,7 +1230,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
INDArray row1 = oneThroughFour.getRow(1);
|
||||
row1.addi(1);
|
||||
INDArray result = Nd4j.create(new double[] {1, 2, 4, 5}, new long[] {2, 2});
|
||||
assertEquals(result, oneThroughFour,getFailureMessage());
|
||||
assertEquals(result, oneThroughFour,getFailureMessage(backend));
|
||||
|
||||
|
||||
}
|
||||
|
@ -1241,8 +1244,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
INDArray linear = test.reshape(-1);
|
||||
linear.putScalar(2, 6);
|
||||
linear.putScalar(3, 7);
|
||||
assertEquals(6, linear.getFloat(2), 1e-1,getFailureMessage());
|
||||
assertEquals(7, linear.getFloat(3), 1e-1,getFailureMessage());
|
||||
assertEquals(6, linear.getFloat(2), 1e-1,getFailureMessage(backend));
|
||||
assertEquals(7, linear.getFloat(3), 1e-1,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
|
||||
|
@ -1609,7 +1612,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
||||
INDArray neg = Transforms.neg(n);
|
||||
assertEquals(assertion, neg,getFailureMessage());
|
||||
assertEquals(assertion, neg,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -1622,13 +1625,13 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||
double assertion = 5.47722557505;
|
||||
double norm3 = n.norm2Number().doubleValue();
|
||||
assertEquals(assertion, norm3, 1e-1,getFailureMessage());
|
||||
assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
|
||||
|
||||
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
||||
INDArray row1 = row.getRow(1);
|
||||
double norm2 = row1.norm2Number().doubleValue();
|
||||
double assertion2 = 5.0f;
|
||||
assertEquals(assertion2, norm2, 1e-1,getFailureMessage());
|
||||
assertEquals(assertion2, norm2, 1e-1,getFailureMessage(backend));
|
||||
|
||||
Nd4j.setDataType(initialType);
|
||||
}
|
||||
|
@ -1640,14 +1643,14 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||
float assertion = 5.47722557505f;
|
||||
float norm3 = n.norm2Number().floatValue();
|
||||
assertEquals(assertion, norm3, 1e-1,getFailureMessage());
|
||||
assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
|
||||
|
||||
|
||||
INDArray row = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
|
||||
INDArray row1 = row.getRow(1);
|
||||
float norm2 = row1.norm2Number().floatValue();
|
||||
float assertion2 = 5.0f;
|
||||
assertEquals(assertion2, norm2, 1e-1,getFailureMessage());
|
||||
assertEquals(assertion2, norm2, 1e-1,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -1659,7 +1662,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||
double sim = Transforms.cosineSim(vec1, vec2);
|
||||
assertEquals(1, sim, 1e-1,getFailureMessage());
|
||||
assertEquals(1, sim, 1e-1,getFailureMessage(backend));
|
||||
|
||||
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
|
||||
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
|
||||
|
@ -1675,14 +1678,14 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
double assertion = 2;
|
||||
INDArray answer = Nd4j.create(new double[] {2, 4, 6, 8});
|
||||
INDArray scal = Nd4j.getBlasWrapper().scal(assertion, answer);
|
||||
assertEquals(answer, scal,getFailureMessage());
|
||||
assertEquals(answer, scal,getFailureMessage(backend));
|
||||
|
||||
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
||||
INDArray row1 = row.getRow(1);
|
||||
double assertion2 = 5.0;
|
||||
INDArray answer2 = Nd4j.create(new double[] {15, 20});
|
||||
INDArray scal2 = Nd4j.getBlasWrapper().scal(assertion2, row1);
|
||||
assertEquals(answer2, scal2,getFailureMessage());
|
||||
assertEquals(answer2, scal2,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -2076,17 +2079,17 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
INDArray innerProduct = n.mmul(transposed);
|
||||
|
||||
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
|
||||
assertEquals(scalar, innerProduct,getFailureMessage());
|
||||
assertEquals(scalar, innerProduct,getFailureMessage(backend));
|
||||
|
||||
INDArray outerProduct = transposed.mmul(n);
|
||||
assertEquals(true, Shape.shapeEquals(new long[] {10, 10}, outerProduct.shape()),getFailureMessage());
|
||||
assertEquals(true, Shape.shapeEquals(new long[] {10, 10}, outerProduct.shape()),getFailureMessage(backend));
|
||||
|
||||
|
||||
|
||||
INDArray three = Nd4j.create(new double[] {3, 4});
|
||||
INDArray test = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2});
|
||||
INDArray sliceRow = test.slice(0).getRow(1);
|
||||
assertEquals(three, sliceRow,getFailureMessage());
|
||||
assertEquals(three, sliceRow,getFailureMessage(backend));
|
||||
|
||||
INDArray twoSix = Nd4j.create(new double[] {2, 6}, new long[] {2, 1});
|
||||
INDArray threeTwoSix = three.mmul(twoSix);
|
||||
|
@ -2114,7 +2117,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
INDArray k1 = n1.transpose();
|
||||
|
||||
INDArray testVectorVector = k1.mmul(n1);
|
||||
assertEquals(vectorVector, testVectorVector,getFailureMessage());
|
||||
assertEquals(vectorVector, testVectorVector,getFailureMessage(backend));
|
||||
|
||||
|
||||
}
|
||||
|
@ -2204,7 +2207,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(linear.getDouble(0, 1), 1, 1e-1);
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testSize(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
INDArray arr = Nd4j.create(4, 5);
|
||||
|
@ -2357,7 +2361,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled
|
||||
public void testTensorDot(Nd4jBackend backend) {
|
||||
INDArray oneThroughSixty = Nd4j.arange(60).reshape(3, 4, 5).castTo(DataType.DOUBLE);
|
||||
|
@ -3051,10 +3056,10 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
public void testMeans(Nd4jBackend backend) {
|
||||
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
INDArray mean1 = a.mean(1);
|
||||
assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage());
|
||||
assertEquals(Nd4j.create(new double[] {2, 3}), a.mean(0),getFailureMessage());
|
||||
assertEquals(2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1,getFailureMessage());
|
||||
assertEquals(2.5, a.meanNumber().doubleValue(), 1e-1,getFailureMessage());
|
||||
assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage(backend));
|
||||
assertEquals(Nd4j.create(new double[] {2, 3}), a.mean(0),getFailureMessage(backend));
|
||||
assertEquals(2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||
assertEquals(2.5, a.meanNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -3063,9 +3068,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testSums(Nd4jBackend backend) {
|
||||
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage());
|
||||
assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage());
|
||||
assertEquals(10, a.sumNumber().doubleValue(), 1e-1,getFailureMessage());
|
||||
assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage(backend));
|
||||
assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage(backend));
|
||||
assertEquals(10, a.sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||
|
||||
|
||||
}
|
||||
|
@ -3438,7 +3443,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled
|
||||
public void largeInstantiation(Nd4jBackend backend) {
|
||||
Nd4j.ones((1024 * 1024 * 511) + (1024 * 1024 - 1)); // Still works; this can even be called as often as I want, allowing me even to spill over on disk
|
||||
|
@ -3487,7 +3493,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(cSum, fSum); //Expect: 4,6. Getting [4, 4] for f order
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled //not relevant anymore
|
||||
public void testAssignMixedC(Nd4jBackend backend) {
|
||||
int[] shape1 = {3, 2, 2, 2, 2, 2};
|
||||
|
@ -3787,7 +3794,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(assertion, result);
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testPullRowsValidation1(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
Nd4j.pullRows(Nd4j.create(10, 10), 2, new int[] {0, 1, 2});
|
||||
|
@ -3795,7 +3803,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
});
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testPullRowsValidation2(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, -1, 2});
|
||||
|
@ -3803,7 +3812,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
});
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testPullRowsValidation3(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, 1, 10});
|
||||
|
@ -3811,7 +3821,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
});
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testPullRowsValidation4(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2, 3});
|
||||
|
@ -3819,7 +3830,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
});
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testPullRowsValidation5(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2}, 'e');
|
||||
|
@ -4975,7 +4987,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testTadReduce3_5(Nd4jBackend backend) {
|
||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||
INDArray initial = Nd4j.create(5, 10);
|
||||
|
@ -6004,7 +6017,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled
|
||||
public void testLogExpSum1(Nd4jBackend backend) {
|
||||
INDArray matrix = Nd4j.create(3, 3);
|
||||
|
@ -6019,7 +6033,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled
|
||||
public void testLogExpSum2(Nd4jBackend backend) {
|
||||
INDArray row = Nd4j.create(new double[]{1, 2, 3});
|
||||
|
@ -6246,7 +6261,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testReshapeFailure(Nd4jBackend backend) {
|
||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||
val a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2);
|
||||
|
@ -6345,7 +6361,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
assertArrayEquals(new long[]{3, 2}, newShape.shape());
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testTranspose1(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
|
||||
|
@ -6360,7 +6377,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testTranspose2(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
val scalar = Nd4j.scalar(2.f);
|
||||
|
@ -6375,7 +6393,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
//@Disabled
|
||||
public void testMatmul_128by256(Nd4jBackend backend) {
|
||||
val mA = Nd4j.create(128, 156).assign(1.0f);
|
||||
|
@ -6647,7 +6666,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(exp1, out1);
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testBadReduce3Call(Nd4jBackend backend) {
|
||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||
val x = Nd4j.create(400,20);
|
||||
|
@ -7392,8 +7412,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(ez, z);
|
||||
}
|
||||
|
||||
@Test()
|
||||
public void testBroadcastInvalid(){
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testBroadcastInvalid() {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
INDArray arr1 = Nd4j.ones(3,4,1);
|
||||
|
||||
|
@ -7656,7 +7677,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(exp, array);
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testScatterUpdateShortcut_f1(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
val array = Nd4j.create(DataType.FLOAT, 5, 2);
|
||||
|
@ -8041,7 +8063,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(exp, out); //Failing here
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testPullRowsFailure(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
val idxs = new int[]{0,2,3,4};
|
||||
|
@ -8144,7 +8167,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(exp1, out1); //This is OK
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testPutRowValidation(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
val matrix = Nd4j.create(5, 10);
|
||||
|
@ -8155,7 +8179,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testPutColumnValidation(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
val matrix = Nd4j.create(5, 10);
|
||||
|
@ -8236,7 +8261,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testScalarEq(){
|
||||
public void testScalarEq(Nd4jBackend backend){
|
||||
INDArray scalarRank2 = Nd4j.scalar(10.0).reshape(1,1);
|
||||
INDArray scalarRank1 = Nd4j.scalar(10.0).reshape(1);
|
||||
INDArray scalarRank0 = Nd4j.scalar(10.0);
|
||||
|
@ -8273,7 +8298,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testType1(@TempDir Path testDir) throws IOException {
|
||||
@Disabled
|
||||
public void testType1(Nd4jBackend backend) throws IOException {
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
INDArray in1 = Nd4j.rand(DataType.DOUBLE, new int[]{100, 100});
|
||||
File dir = testDir.toFile();
|
||||
|
@ -8295,7 +8321,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testOnes(){
|
||||
public void testOnes(Nd4jBackend backend){
|
||||
INDArray arr = Nd4j.ones();
|
||||
INDArray arr2 = Nd4j.ones(DataType.LONG);
|
||||
assertEquals(0, arr.rank());
|
||||
|
@ -8306,7 +8332,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testZeros(){
|
||||
public void testZeros(Nd4jBackend backend){
|
||||
INDArray arr = Nd4j.zeros();
|
||||
INDArray arr2 = Nd4j.zeros(DataType.LONG);
|
||||
assertEquals(0, arr.rank());
|
||||
|
@ -8317,7 +8343,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testType2(@TempDir Path testDir) throws IOException {
|
||||
@Disabled
|
||||
public void testType2(Nd4jBackend backend) throws IOException {
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
INDArray in1 = Nd4j.ones(DataType.UINT16);
|
||||
File dir = testDir.toFile();
|
||||
|
|
|
@ -23,6 +23,7 @@ package org.nd4j.linalg;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
@ -58,11 +59,12 @@ public class ToStringTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testToStringScalars(){
|
||||
@Disabled
|
||||
public void testToStringScalars(Nd4jBackend backend){
|
||||
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
|
||||
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};
|
||||
|
||||
for(int dt=0; dt<5; dt++ ) {
|
||||
for(int dt = 0; dt < 5; dt++) {
|
||||
for (int i = 0; i < 5; i++) {
|
||||
long[] shape = ArrayUtil.nTimes(i, 1L);
|
||||
INDArray scalar = Nd4j.scalar(1.0f).castTo(dataTypes[dt]).reshape(shape);
|
||||
|
|
|
@ -64,7 +64,6 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
|
@ -79,7 +78,6 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
|
@ -100,7 +98,8 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testCreateNpy3(Nd4jBackend backend) throws Exception {
|
||||
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
|
||||
assertEquals(8, arrCreate.length());
|
||||
|
@ -111,8 +110,9 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
|||
assertEquals(arrCreate.data().address(), pointer.address());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled // this is endless test
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testEndlessAllocation(Nd4jBackend backend) {
|
||||
Nd4j.getEnvironment().setMaxSpecialMemory(1);
|
||||
while (true) {
|
||||
|
@ -121,9 +121,10 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled("This test is designed to run in isolation. With parallel gc it makes no real sense since allocated amount changes at any time")
|
||||
public void testAllocationLimits() throws Exception {
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testAllocationLimits(Nd4jBackend backend) throws Exception {
|
||||
Nd4j.create(1);
|
||||
|
||||
val origDeviceLimit = Nd4j.getEnvironment().getDeviceLimit(0);
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
|
||||
package org.nd4j.linalg.api;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
|
|
|
@ -59,7 +59,7 @@ public class Level1Test extends BaseNd4jTestWithBackends {
|
|||
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
INDArray row = matrix.getRow(1);
|
||||
Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row);
|
||||
assertEquals(Nd4j.create(new double[] {4, 8}), row,getFailureMessage());
|
||||
assertEquals(Nd4j.create(new double[] {4, 8}), row,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -70,8 +70,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
|
|||
/**
|
||||
* Testing level1 blas
|
||||
*/
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testBlasValidation1(Nd4jBackend backend) {
|
||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||
|
@ -89,8 +88,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
|
|||
/**
|
||||
* Testing level2 blas
|
||||
*/
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testBlasValidation2(Nd4jBackend backend) {
|
||||
assertThrows(RuntimeException.class,() -> {
|
||||
|
@ -109,8 +107,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
|
|||
/**
|
||||
* Testing level3 blas
|
||||
*/
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testBlasValidation3(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
|
|
|
@ -88,7 +88,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
|||
float[] d1 = new float[] {1, 2, 3, 4};
|
||||
DataBuffer d = Nd4j.createBuffer(d1);
|
||||
float[] d2 = d.asFloat();
|
||||
assertArrayEquals( d1, d2, 1e-1f,getFailureMessage());
|
||||
assertArrayEquals( d1, d2, 1e-1f,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -146,7 +146,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
|||
d.put(0, 0.0);
|
||||
float[] result = new float[] {0, 2, 3, 4};
|
||||
d1 = d.asFloat();
|
||||
assertArrayEquals(d1, result, 1e-1f,getFailureMessage());
|
||||
assertArrayEquals(d1, result, 1e-1f,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
|
||||
|
@ -156,12 +156,12 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
|||
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
||||
float[] get = buffer.getFloatsAt(0, 3);
|
||||
float[] data = new float[] {1, 2, 3};
|
||||
assertArrayEquals(get, data, 1e-1f,getFailureMessage());
|
||||
assertArrayEquals(get, data, 1e-1f,getFailureMessage(backend));
|
||||
|
||||
|
||||
float[] get2 = buffer.asFloat();
|
||||
float[] allData = buffer.getFloatsAt(0, (int) buffer.length());
|
||||
assertArrayEquals(get2, allData, 1e-1f,getFailureMessage());
|
||||
assertArrayEquals(get2, allData, 1e-1f,getFailureMessage(backend));
|
||||
|
||||
|
||||
}
|
||||
|
@ -173,13 +173,13 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
|||
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
||||
float[] get = buffer.getFloatsAt(1, 3);
|
||||
float[] data = new float[] {2, 3, 4};
|
||||
assertArrayEquals(get, data, 1e-1f,getFailureMessage());
|
||||
assertArrayEquals(get, data, 1e-1f,getFailureMessage(backend));
|
||||
|
||||
|
||||
float[] allButLast = new float[] {2, 3, 4, 5};
|
||||
|
||||
float[] allData = buffer.getFloatsAt(1, (int) buffer.length());
|
||||
assertArrayEquals(allButLast, allData, 1e-1f,getFailureMessage());
|
||||
assertArrayEquals(allButLast, allData, 1e-1f,getFailureMessage(backend));
|
||||
|
||||
|
||||
}
|
||||
|
@ -190,7 +190,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
|||
public void testAsBytes(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.create(5);
|
||||
byte[] d = arr.data().asBytes();
|
||||
assertEquals(4 * 5, d.length,getFailureMessage());
|
||||
assertEquals(4 * 5, d.length,getFailureMessage(backend));
|
||||
INDArray rand = Nd4j.rand(3, 3);
|
||||
rand.data().asBytes();
|
||||
|
||||
|
|
|
@ -20,26 +20,18 @@
|
|||
|
||||
package org.nd4j.linalg.api.indexing;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.common.util.ArrayUtil;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||
import org.nd4j.linalg.indexing.IntervalIndex;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndexAll;
|
||||
import org.nd4j.linalg.indexing.NewAxis;
|
||||
import org.nd4j.linalg.indexing.PointIndex;
|
||||
import org.nd4j.linalg.indexing.SpecifiedIndex;
|
||||
import org.nd4j.linalg.indexing.*;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
import org.nd4j.common.util.ArrayUtil;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
|
@ -56,22 +48,22 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testNegativeBounds() {
|
||||
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
|
||||
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
|
||||
INDArray get = arr.get(NDArrayIndex.all(),interval);
|
||||
INDArray assertion = Nd4j.create(new double[][]{
|
||||
{1,2,3},
|
||||
{6,7,8}
|
||||
});
|
||||
assertEquals(assertion,get);
|
||||
public void testNegativeBounds(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
|
||||
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
|
||||
INDArray get = arr.get(NDArrayIndex.all(),interval);
|
||||
INDArray assertion = Nd4j.create(new double[][]{
|
||||
{1,2,3},
|
||||
{6,7,8}
|
||||
});
|
||||
assertEquals(assertion,get);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testNewAxis() {
|
||||
public void testNewAxis(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2);
|
||||
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all());
|
||||
long[] shapeAssertion = {3, 2, 1, 1, 2};
|
||||
|
@ -79,9 +71,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void broadcastBug() {
|
||||
public void broadcastBug(Nd4jBackend backend) {
|
||||
INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2});
|
||||
final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0));
|
||||
|
||||
|
@ -91,9 +83,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testIntervalsIn3D() {
|
||||
public void testIntervalsIn3D(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
|
||||
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
|
||||
INDArray rest = arr.get(interval(1, 2), interval(0, 2), interval(0, 2));
|
||||
|
@ -101,9 +93,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testSmallInterval() {
|
||||
public void testSmallInterval(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
|
||||
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
|
||||
INDArray rest = arr.get(interval(1, 2), all(), all());
|
||||
|
@ -111,9 +103,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testAllWithNewAxisAndInterval() {
|
||||
public void testAllWithNewAxisAndInterval(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3);
|
||||
|
||||
|
@ -121,9 +113,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(assertion2, get2);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testAllWithNewAxisInMiddle() {
|
||||
public void testAllWithNewAxisInMiddle(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3);
|
||||
|
||||
|
@ -131,20 +123,20 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(assertion2, get2);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testAllWithNewAxis() {
|
||||
public void testAllWithNewAxis(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||
INDArray get = arr.get(newAxis(), all(), point(1));
|
||||
INDArray assertion = Nd4j.create(new double[][] {{4, 5, 6}, {10, 11, 12}, {16, 17, 18}, {22, 23, 24}})
|
||||
.reshape(1, 4, 3);
|
||||
.reshape(1, 4, 3);
|
||||
assertEquals(assertion, get);
|
||||
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testIndexingWithMmul() {
|
||||
public void testIndexingWithMmul(Nd4jBackend backend) {
|
||||
INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
|
||||
INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
|
||||
// System.out.println(b);
|
||||
|
@ -154,9 +146,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(assertion, c);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testPointPointInterval() {
|
||||
public void testPointPointInterval(Nd4jBackend backend) {
|
||||
INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3);
|
||||
INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3));
|
||||
INDArray assertion = Nd4j.create(new double[][] {{5, 6}, {8, 9}});
|
||||
|
@ -164,9 +156,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(assertion, get);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testIntervalLowerBound() {
|
||||
public void testIntervalLowerBound(Nd4jBackend backend) {
|
||||
INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
|
||||
INDArray assertion = Nd4j.create(new double[][] {{7, 9}, {13, 15}});
|
||||
|
@ -176,9 +168,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testGetPointRowVector() {
|
||||
public void testGetPointRowVector(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1);
|
||||
|
||||
INDArray arr2 = arr.get(point(0), interval(0, 100));
|
||||
|
@ -187,9 +179,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(Nd4j.linspace(1, 100, 100, DataType.DOUBLE), arr2);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testSpecifiedIndexVector() {
|
||||
public void testSpecifiedIndexVector(Nd4jBackend backend) {
|
||||
INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
|
||||
INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2);
|
||||
INDArray get = rootMatrix.get(all(), new SpecifiedIndex(0, 2));
|
||||
|
@ -205,9 +197,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testPutRowIndexing() {
|
||||
public void testPutRowIndexing(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.ones(1, 10);
|
||||
INDArray row = Nd4j.create(1, 10);
|
||||
|
||||
|
@ -216,9 +208,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(arr, row);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testVectorIndexing2() {
|
||||
public void testVectorIndexing2(Nd4jBackend backend) {
|
||||
INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true));
|
||||
INDArray assertion = Nd4j.create(new double[] {2, 4});
|
||||
assertEquals(assertion, wholeVector);
|
||||
|
@ -232,9 +224,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testOffsetsC() {
|
||||
public void testOffsetsC(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||
assertEquals(3, NDArrayIndex.offset(arr, 1, 1));
|
||||
assertEquals(3, NDArrayIndex.offset(arr, point(1), point(1)));
|
||||
|
@ -249,9 +241,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testIndexFor() {
|
||||
public void testIndexFor(Nd4jBackend backend) {
|
||||
long[] shape = {1, 2};
|
||||
INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape);
|
||||
for (int i = 0; i < indexes.length; i++) {
|
||||
|
@ -259,9 +251,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testGetScalar() {
|
||||
public void testGetScalar(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
|
||||
INDArray d = arr.get(point(1));
|
||||
assertTrue(d.isScalar());
|
||||
|
@ -269,26 +261,26 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testVectorIndexing() {
|
||||
public void testVectorIndexing(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1);
|
||||
INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
|
||||
INDArray viewTest = arr.get(point(0), interval(1, 5));
|
||||
assertEquals(assertion, viewTest);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testNegativeIndices() {
|
||||
public void testNegativeIndices(Nd4jBackend backend) {
|
||||
INDArray test = Nd4j.create(10, 10, 10);
|
||||
test.putScalar(new int[] {0, 0, -1}, 1.0);
|
||||
assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testGetIndices2d() {
|
||||
public void testGetIndices2d(Nd4jBackend backend) {
|
||||
INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2);
|
||||
INDArray firstRow = twoByTwo.getRow(0);
|
||||
INDArray secondRow = twoByTwo.getRow(1);
|
||||
|
@ -305,9 +297,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(Nd4j.create(new double[] {4}, new int[]{1,1}), individualElement);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testGetRow() {
|
||||
public void testGetRow(Nd4jBackend backend) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5);
|
||||
int[] toGet = {0, 1};
|
||||
|
@ -323,9 +315,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testGetRowEdgeCase() {
|
||||
public void testGetRowEdgeCase(Nd4jBackend backend) {
|
||||
INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
|
||||
INDArray get = rowVec.getRow(0); //Returning shape [1,1]
|
||||
|
||||
|
@ -333,9 +325,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(rowVec, get);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testGetColumnEdgeCase() {
|
||||
public void testGetColumnEdgeCase(Nd4jBackend backend) {
|
||||
INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose();
|
||||
INDArray get = colVec.getColumn(0); //Returning shape [1,1]
|
||||
|
||||
|
@ -343,9 +335,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(colVec, get);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testConcatColumns() {
|
||||
public void testConcatColumns(Nd4jBackend backend) {
|
||||
INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE);
|
||||
INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE);
|
||||
INDArray concat = Nd4j.concat(1, input1, input2);
|
||||
|
@ -353,18 +345,18 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(assertion, concat);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testGetIndicesVector() {
|
||||
public void testGetIndicesVector(Nd4jBackend backend) {
|
||||
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
|
||||
INDArray test = Nd4j.create(new double[] {2, 3});
|
||||
INDArray result = line.get(point(0), interval(1, 3));
|
||||
assertEquals(test, result);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testArangeMul() {
|
||||
public void testArangeMul(Nd4jBackend backend) {
|
||||
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
|
||||
INDArrayIndex index = interval(0, 2);
|
||||
INDArray get = arange.get(index, index);
|
||||
|
@ -374,7 +366,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
assertEquals(assertion, mul);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testIndexingThorough(){
|
||||
long[] fullShape = {3,4,5,6,7};
|
||||
|
@ -575,7 +567,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
|||
return d;
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void debugging(){
|
||||
long[] inShape = {3,4};
|
||||
|
|
|
@ -46,12 +46,13 @@ import java.util.List;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Slf4j
|
||||
|
||||
public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
|
||||
|
||||
@TempDir Path testDir;
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void compareAfterWrite(Nd4jBackend backend) throws Exception {
|
||||
int [] ranksToCheck = new int[] {0,1,2,3,4};
|
||||
for (int i = 0; i < ranksToCheck.length; i++) {
|
||||
// log.info("Checking read write arrays with rank " + ranksToCheck[i]);
|
||||
|
@ -82,7 +83,7 @@ public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testNd4jReadWriteText(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void testNd4jReadWriteText(Nd4jBackend backend) throws Exception {
|
||||
|
||||
File dir = testDir.toFile();
|
||||
int count = 0;
|
||||
|
|
|
@ -38,11 +38,11 @@ import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays;
|
|||
@Slf4j
|
||||
|
||||
public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends {
|
||||
|
||||
@TempDir Path testDir;
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void compareAfterWrite(Nd4jBackend backend) throws Exception {
|
||||
int[] ranksToCheck = new int[]{0, 1, 2, 3, 4};
|
||||
for (int i = 0; i < ranksToCheck.length; i++) {
|
||||
log.info("Checking read write arrays with rank " + ranksToCheck[i]);
|
||||
|
|
|
@ -22,6 +22,7 @@ package org.nd4j.linalg.broadcast;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
@ -135,7 +136,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
|||
assertEquals(e, z);
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void basicBroadcastFailureTest_1(Nd4jBackend backend) {
|
||||
|
@ -146,7 +146,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
|||
});
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void basicBroadcastFailureTest_2(Nd4jBackend backend) {
|
||||
|
@ -158,7 +157,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void basicBroadcastFailureTest_3(Nd4jBackend backend) {
|
||||
|
@ -170,16 +168,15 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled
|
||||
public void basicBroadcastFailureTest_4(Nd4jBackend backend) {
|
||||
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
|
||||
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
|
||||
val z = x.addi(y);
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void basicBroadcastFailureTest_5(Nd4jBackend backend) {
|
||||
|
@ -191,7 +188,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void basicBroadcastFailureTest_6(Nd4jBackend backend) {
|
||||
|
@ -249,9 +245,9 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
|||
assertEquals(y, z);
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled
|
||||
public void emptyBroadcastTest_2(Nd4jBackend backend) {
|
||||
val x = Nd4j.create(DataType.FLOAT, 1, 2);
|
||||
val y = Nd4j.create(DataType.FLOAT, 0, 2);
|
||||
|
|
|
@ -37,7 +37,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class CompressionMagicTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
@BeforeEach
|
||||
public void setUp(Nd4jBackend backend) {
|
||||
public void setUp() {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -48,6 +48,7 @@ import java.util.Set;
|
|||
|
||||
public class DeconvTests extends BaseNd4jTestWithBackends {
|
||||
|
||||
@TempDir Path testDir;
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
@ -56,7 +57,7 @@ public class DeconvTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void compareKeras(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void compareKeras(Nd4jBackend backend) throws Exception {
|
||||
File newFolder = testDir.toFile();
|
||||
new ClassPathResource("keras/deconv/").copyDirectory(newFolder);
|
||||
|
||||
|
|
|
@ -99,7 +99,8 @@ public class SpecialTests extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testScalarShuffle1(Nd4jBackend backend) {
|
||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||
List<DataSet> listData = new ArrayList<>();
|
||||
|
|
|
@ -195,7 +195,8 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
|||
assertEquals(exp, arrayX);
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testInplaceOp1(Nd4jBackend backend) {
|
||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||
val arrayX = Nd4j.create(10, 10);
|
||||
|
|
|
@ -41,10 +41,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
|||
|
||||
public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
@TempDir Path testDir;
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testBalance(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void testBalance(Nd4jBackend backend) throws Exception {
|
||||
DataSetIterator iterator = new IrisDataSetIterator(10, 150);
|
||||
|
||||
File minibatches = new File(testDir.toFile(),"mini-batch-dir");
|
||||
|
@ -62,7 +63,7 @@ public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testMiniBatchBalanced(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
||||
public void testMiniBatchBalanced(Nd4jBackend backend) throws Exception {
|
||||
|
||||
int miniBatchSize = 100;
|
||||
DataSetIterator iterator = new IrisDataSetIterator(miniBatchSize, 150);
|
||||
|
|
|
@ -51,8 +51,10 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
|
|||
@Slf4j
|
||||
|
||||
public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
@ParameterizedTest
|
||||
|
||||
@TempDir Path testDir;
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testViewIterator(Nd4jBackend backend) {
|
||||
DataSetIterator iter = new ViewIterator(new IrisDataSetIterator(150, 150).next(), 10);
|
||||
|
@ -106,9 +108,9 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testSplitTestAndTrain (Nd4jBackend backend) {
|
||||
public void testSplitTestAndTrain(Nd4jBackend backend) {
|
||||
INDArray labels = FeatureUtil.toOutcomeMatrix(new int[] {0, 0, 0, 0, 0, 0, 0, 0}, 1);
|
||||
DataSet data = new DataSet(Nd4j.rand(8, 1), labels);
|
||||
|
||||
|
@ -116,7 +118,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
|||
assertEquals(train.getTrain().getLabels().length(), 6);
|
||||
|
||||
SplitTestAndTrain train2 = data.splitTestAndTrain(6, new Random(1));
|
||||
assertEquals(train.getTrain().getFeatures(), train2.getTrain().getFeatures(),getFailureMessage());
|
||||
assertEquals(train.getTrain().getFeatures(), train2.getTrain().getFeatures(),getFailureMessage(backend));
|
||||
|
||||
DataSet x0 = new IrisDataSetIterator(150, 150).next();
|
||||
SplitTestAndTrain testAndTrain = x0.splitTestAndTrain(10);
|
||||
|
@ -144,7 +146,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
|||
SplitTestAndTrain testAndTrainRng = x2.splitTestAndTrain(10, rngHere);
|
||||
|
||||
assertArrayEquals(testAndTrainRng.getTrain().getFeatures().shape(),
|
||||
testAndTrain.getTrain().getFeatures().shape());
|
||||
testAndTrain.getTrain().getFeatures().shape());
|
||||
assertEquals(testAndTrainRng.getTrain().getFeatures(), testAndTrain.getTrain().getFeatures());
|
||||
assertEquals(testAndTrainRng.getTrain().getLabels(), testAndTrain.getTrain().getLabels());
|
||||
|
||||
|
@ -154,13 +156,13 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
|||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testLabelCounts(Nd4jBackend backend) {
|
||||
DataSet x0 = new IrisDataSetIterator(150, 150).next();
|
||||
assertEquals(0, x0.get(0).outcome(),getFailureMessage());
|
||||
assertEquals( 0, x0.get(1).outcome(),getFailureMessage());
|
||||
assertEquals(2, x0.get(149).outcome(),getFailureMessage());
|
||||
assertEquals(0, x0.get(0).outcome(),getFailureMessage(backend));
|
||||
assertEquals( 0, x0.get(1).outcome(),getFailureMessage(backend));
|
||||
assertEquals(2, x0.get(149).outcome(),getFailureMessage(backend));
|
||||
Map<Integer, Double> counts = x0.labelCounts();
|
||||
assertEquals(50, counts.get(0), 1e-1,getFailureMessage());
|
||||
assertEquals(50, counts.get(1), 1e-1,getFailureMessage());
|
||||
assertEquals(50, counts.get(2), 1e-1,getFailureMessage());
|
||||
assertEquals(50, counts.get(0), 1e-1,getFailureMessage(backend));
|
||||
assertEquals(50, counts.get(1), 1e-1,getFailureMessage(backend));
|
||||
assertEquals(50, counts.get(2), 1e-1,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
|
@ -694,14 +696,14 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
INDArray expLabels3d = Nd4j.create(3, 3, 4);
|
||||
expLabels3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)},
|
||||
l3d1);
|
||||
l3d1);
|
||||
expLabels3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
|
||||
NDArrayIndex.interval(0, 3)}, l3d2);
|
||||
NDArrayIndex.interval(0, 3)}, l3d2);
|
||||
INDArray expLM3d = Nd4j.create(3, 3, 4);
|
||||
expLM3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)},
|
||||
lm3d1);
|
||||
lm3d1);
|
||||
expLM3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
|
||||
NDArrayIndex.interval(0, 3)}, lm3d2);
|
||||
NDArrayIndex.interval(0, 3)}, lm3d2);
|
||||
|
||||
|
||||
DataSet merged3d = DataSet.merge(Arrays.asList(ds3d1, ds3d2));
|
||||
|
@ -752,52 +754,52 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
|||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testShuffleNd(Nd4jBackend backend) {
|
||||
int numDims = 7;
|
||||
int nLabels = 3;
|
||||
Random r = new Random();
|
||||
int numDims = 7;
|
||||
int nLabels = 3;
|
||||
Random r = new Random();
|
||||
|
||||
|
||||
int[] shape = new int[numDims];
|
||||
int entries = 1;
|
||||
for (int i = 0; i < numDims; i++) {
|
||||
//randomly generating shapes bigger than 1
|
||||
shape[i] = r.nextInt(4) + 2;
|
||||
entries *= shape[i];
|
||||
}
|
||||
int labels = shape[0] * nLabels;
|
||||
int[] shape = new int[numDims];
|
||||
int entries = 1;
|
||||
for (int i = 0; i < numDims; i++) {
|
||||
//randomly generating shapes bigger than 1
|
||||
shape[i] = r.nextInt(4) + 2;
|
||||
entries *= shape[i];
|
||||
}
|
||||
int labels = shape[0] * nLabels;
|
||||
|
||||
INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape);
|
||||
INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels);
|
||||
INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape);
|
||||
INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels);
|
||||
|
||||
DataSet ds = new DataSet(ds_data, ds_labels);
|
||||
ds.shuffle();
|
||||
DataSet ds = new DataSet(ds_data, ds_labels);
|
||||
ds.shuffle();
|
||||
|
||||
//Checking Nd dataset which is the data
|
||||
for (int dim = 1; dim < numDims; dim++) {
|
||||
//get tensor along dimension - the order in every dimension but zero should be preserved
|
||||
for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) {
|
||||
//the difference between consecutive elements should be equal to the stride
|
||||
for (int i = 0, j = 1; j < shape[dim]; i++, j++) {
|
||||
int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i);
|
||||
int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j);
|
||||
int f_element_diff = f_next_element - f_element;
|
||||
assertEquals(f_element_diff, ds_data.stride(dim));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Checking 2d, features
|
||||
int dim = 1;
|
||||
//Checking Nd dataset which is the data
|
||||
for (int dim = 1; dim < numDims; dim++) {
|
||||
//get tensor along dimension - the order in every dimension but zero should be preserved
|
||||
for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) {
|
||||
for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) {
|
||||
//the difference between consecutive elements should be equal to the stride
|
||||
for (int i = 0, j = 1; j < nLabels; i++, j++) {
|
||||
int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i);
|
||||
int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j);
|
||||
int l_element_diff = l_next_element - l_element;
|
||||
assertEquals(l_element_diff, ds_labels.stride(dim));
|
||||
for (int i = 0, j = 1; j < shape[dim]; i++, j++) {
|
||||
int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i);
|
||||
int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j);
|
||||
int f_element_diff = f_next_element - f_element;
|
||||
assertEquals(f_element_diff, ds_data.stride(dim));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Checking 2d, features
|
||||
int dim = 1;
|
||||
//get tensor along dimension - the order in every dimension but zero should be preserved
|
||||
for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) {
|
||||
//the difference between consecutive elements should be equal to the stride
|
||||
for (int i = 0, j = 1; j < nLabels; i++, j++) {
|
||||
int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i);
|
||||
int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j);
|
||||
int l_element_diff = l_next_element - l_element;
|
||||
assertEquals(l_element_diff, ds_labels.stride(dim));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -936,9 +938,9 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
//Checking if the features and labels are equal
|
||||
assertEquals(iDataSet.getFeatures(),
|
||||
dsList.get(i).getFeatures().get(all(), all(), interval(0, minTSLength + i)));
|
||||
dsList.get(i).getFeatures().get(all(), all(), interval(0, minTSLength + i)));
|
||||
assertEquals(iDataSet.getLabels(),
|
||||
dsList.get(i).getLabels().get(all(), all(), interval(0, minTSLength + i)));
|
||||
dsList.get(i).getLabels().get(all(), all(), interval(0, minTSLength + i)));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -964,8 +966,8 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
|||
for (boolean lMask : b) {
|
||||
|
||||
DataSet ds = new DataSet((features ? f : null),
|
||||
(labels ? (labelsSameAsFeatures ? f : l) : null), (fMask ? fm : null),
|
||||
(lMask ? lm : null));
|
||||
(labels ? (labelsSameAsFeatures ? f : l) : null), (fMask ? fm : null),
|
||||
(lMask ? lm : null));
|
||||
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
DataOutputStream dos = new DataOutputStream(baos);
|
||||
|
@ -1009,7 +1011,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
|||
boolean lMask = true;
|
||||
|
||||
DataSet ds = new DataSet((features ? f : null), (labels ? (labelsSameAsFeatures ? f : l) : null),
|
||||
(fMask ? fm : null), (lMask ? lm : null));
|
||||
(fMask ? fm : null), (lMask ? lm : null));
|
||||
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
DataOutputStream dos = new DataOutputStream(baos);
|
||||
|
@ -1098,7 +1100,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
|
||||
public void testDataSetMetaDataSerialization(Nd4jBackend backend) throws IOException {
|
||||
|
||||
for(boolean withMeta : new boolean[]{false, true}) {
|
||||
// create simple data set with meta data object
|
||||
|
@ -1129,7 +1131,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testMultiDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend nd4jBackend) throws IOException {
|
||||
public void testMultiDataSetMetaDataSerialization(Nd4jBackend nd4jBackend) throws IOException {
|
||||
|
||||
for(boolean withMeta : new boolean[]{false, true}) {
|
||||
// create simple data set with meta data object
|
||||
|
|
|
@ -106,7 +106,8 @@ public class KFoldIteratorTest extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void checkCornerCaseException(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
DataSet allData = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1),
|
||||
|
|
|
@ -21,27 +21,25 @@
|
|||
package org.nd4j.linalg.dataset;
|
||||
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import java.nio.file.Path;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
|
||||
|
||||
public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTestWithBackends {
|
||||
|
||||
@TempDir Path testDir;
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testMiniBatches(@TempDir Path testDir) throws Exception {
|
||||
public void testMiniBatches(Nd4jBackend backend) throws Exception {
|
||||
DataSet load = new IrisDataSetIterator(150, 150).next();
|
||||
final MiniBatchFileDataSetIterator iter = new MiniBatchFileDataSetIterator(load, 10, false, testDir.toFile());
|
||||
while (iter.hasNext())
|
||||
|
|
|
@ -39,8 +39,7 @@ public class CompositeDataSetPreProcessorTest extends BaseNd4jTestWithBackends {
|
|||
return 'c';
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void when_preConditionsIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
||||
assertThrows(NullPointerException.class,() -> {
|
||||
|
|
|
@ -41,8 +41,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
|||
return 'c';
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void when_originalHeightIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
|
@ -51,8 +50,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
|||
});
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void when_originalWidthIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
|
@ -61,8 +59,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
|||
});
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void when_yStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
|
@ -71,8 +68,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
|||
});
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void when_xStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
|
@ -81,8 +77,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
|||
});
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
|
@ -91,8 +86,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
|||
});
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
|
@ -101,8 +95,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
|||
});
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
|
@ -111,8 +104,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
|||
});
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
||||
// Assemble
|
||||
|
|
|
@ -39,7 +39,8 @@ public class PermuteDataSetPreProcessorTest extends BaseNd4jTestWithBackends {
|
|||
return 'c';
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
||||
assertThrows(NullPointerException.class,() -> {
|
||||
// Assemble
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
|
||||
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||
|
@ -39,7 +38,8 @@ public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTestWithBacke
|
|||
return 'c';
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
||||
assertThrows(NullPointerException.class,() -> {
|
||||
// Assemble
|
||||
|
|
|
@ -139,7 +139,7 @@ public class Nd4jTest extends BaseNd4jTestWithBackends {
|
|||
INDArray actualResult = data.mean(0);
|
||||
INDArray expectedResult = Nd4j.create(new double[] {3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6.,
|
||||
6., 3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6.}, new int[] {2, 4, 4});
|
||||
assertEquals(expectedResult, actualResult,getFailureMessage());
|
||||
assertEquals(expectedResult, actualResult,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
|
||||
|
@ -154,7 +154,7 @@ public class Nd4jTest extends BaseNd4jTestWithBackends {
|
|||
INDArray actualResult = data.var(false, 0);
|
||||
INDArray expectedResult = Nd4j.create(new double[] {1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4.,
|
||||
4., 1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4.}, new long[] {2, 4, 4});
|
||||
assertEquals(expectedResult, actualResult,getFailureMessage());
|
||||
assertEquals(expectedResult, actualResult,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
|
|
@ -83,8 +83,7 @@ public class CloseableTests extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testAccessException_1(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
|
@ -96,8 +95,7 @@ public class CloseableTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testAccessException_2(Nd4jBackend backend) {
|
||||
assertThrows(IllegalStateException.class,() -> {
|
||||
|
|
|
@ -384,7 +384,9 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
|||
assertEquals(exp, arrayZ);
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
|
||||
public void testTypesValidation_1(Nd4jBackend backend) {
|
||||
assertThrows(IllegalArgumentException.class,() -> {
|
||||
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.LONG);
|
||||
|
@ -397,7 +399,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testTypesValidation_2(Nd4jBackend backend) {
|
||||
assertThrows(RuntimeException.class,() -> {
|
||||
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
||||
|
@ -412,7 +415,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@Test()
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testTypesValidation_3(Nd4jBackend backend) {
|
||||
assertThrows(RuntimeException.class,() -> {
|
||||
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
||||
|
@ -422,6 +426,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testTypesValidation_4(Nd4jBackend backend) {
|
||||
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
||||
val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.DOUBLE);
|
||||
|
@ -485,7 +491,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testBoolFloatCast2(){
|
||||
public void testBoolFloatCast2(Nd4jBackend backend){
|
||||
val first = Nd4j.zeros(DataType.FLOAT, 3, 5000);
|
||||
INDArray asBool = first.castTo(DataType.BOOL);
|
||||
INDArray not = Transforms.not(asBool); //
|
||||
|
@ -516,7 +522,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testAssignScalarSimple(){
|
||||
public void testAssignScalarSimple(Nd4jBackend backend){
|
||||
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
||||
INDArray arr = Nd4j.scalar(dt, 10.0);
|
||||
arr.assign(2.0);
|
||||
|
@ -526,7 +532,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testSimple(){
|
||||
public void testSimple(Nd4jBackend backend){
|
||||
Nd4j.create(1);
|
||||
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, DataType.LONG}) {
|
||||
// System.out.println("----- " + dt + " -----");
|
||||
|
@ -551,7 +557,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testWorkspaceBool(){
|
||||
public void testWorkspaceBool(Nd4jBackend backend){
|
||||
val conf = WorkspaceConfiguration.builder().minSize(10 * 1024 * 1024)
|
||||
.overallocationLimit(1.0).policyAllocation(AllocationPolicy.OVERALLOCATE)
|
||||
.policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL)
|
||||
|
@ -559,7 +565,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(conf, "WS");
|
||||
|
||||
for( int i=0; i<10; i++ ) {
|
||||
for( int i = 0; i < 10; i++ ) {
|
||||
try (val workspace = (Nd4jWorkspace)ws.notifyScopeEntered() ) {
|
||||
val bool = Nd4j.create(DataType.BOOL, 1, 10);
|
||||
val dbl = Nd4j.create(DataType.DOUBLE, 1, 10);
|
||||
|
@ -574,8 +580,9 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@Disabled
|
||||
public void testArrayCreationFromPointer(Nd4jBackend backend) {
|
||||
val source = Nd4j.create(new double[]{1, 2, 3, 4, 5});
|
||||
|
||||
|
|
|
@ -40,13 +40,13 @@ public class NativeBlasTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
|
||||
@BeforeEach
|
||||
public void setUp(Nd4jBackend backend) {
|
||||
public void setUp() {
|
||||
Nd4j.getExecutioner().enableDebugMode(true);
|
||||
Nd4j.getExecutioner().enableVerboseMode(true);
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void setDown(Nd4jBackend backend) {
|
||||
public void setDown() {
|
||||
Nd4j.getExecutioner().enableDebugMode(false);
|
||||
Nd4j.getExecutioner().enableVerboseMode(false);
|
||||
}
|
||||
|
|
|
@ -77,18 +77,18 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
|
||||
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
|
||||
double sim = Transforms.cosineSim(vec1, vec2);
|
||||
assertEquals( 1, sim, 1e-1,getFailureMessage());
|
||||
assertEquals( 1, sim, 1e-1,getFailureMessage(backend));
|
||||
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testCosineDistance(){
|
||||
public void testCosineDistance(Nd4jBackend backend){
|
||||
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3});
|
||||
INDArray vec2 = Nd4j.create(new float[] {3, 5, 7});
|
||||
// 1-17*sqrt(2/581)
|
||||
double distance = Transforms.cosineDistance(vec1, vec2);
|
||||
assertEquals(0.0025851, distance, 1e-7,getFailureMessage());
|
||||
assertEquals(0.0025851, distance, 1e-7,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -97,7 +97,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
INDArray arr = Nd4j.create(new double[] {55, 55});
|
||||
INDArray arr2 = Nd4j.create(new double[] {60, 60});
|
||||
double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).z().getDouble(0);
|
||||
assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage());
|
||||
assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -137,7 +137,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi();
|
||||
INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6);
|
||||
Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1));
|
||||
assertEquals(scalarMax, postMax,getFailureMessage());
|
||||
assertEquals(scalarMax, postMax,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -147,14 +147,14 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1));
|
||||
for (int i = 0; i < linspace.length(); i++) {
|
||||
double val = linspace.getDouble(i);
|
||||
assertTrue( val >= 0 && val <= 1,getFailureMessage());
|
||||
assertTrue( val >= 0 && val <= 1,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
INDArray linspace2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||
Nd4j.getExecutioner().exec(new SetRange(linspace2, 2, 4));
|
||||
for (int i = 0; i < linspace2.length(); i++) {
|
||||
double val = linspace2.getDouble(i);
|
||||
assertTrue( val >= 2 && val <= 4,getFailureMessage());
|
||||
assertTrue( val >= 2 && val <= 4,getFailureMessage(backend));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -163,7 +163,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
public void testNormMax(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||
double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).z().getDouble(0);
|
||||
assertEquals(4, normMax, 1e-1,getFailureMessage());
|
||||
assertEquals(4, normMax, 1e-1,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -187,7 +187,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
public void testNorm2(Nd4jBackend backend) {
|
||||
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||
double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).z().getDouble(0);
|
||||
assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage());
|
||||
assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -198,7 +198,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
INDArray xDup = x.dup();
|
||||
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
|
||||
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
|
||||
assertEquals(solution, x,getFailureMessage());
|
||||
assertEquals(solution, x,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -221,13 +221,13 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
INDArray xDup = x.dup();
|
||||
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
|
||||
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
|
||||
assertEquals(solution, x,getFailureMessage());
|
||||
assertEquals(solution, x,getFailureMessage(backend));
|
||||
Sum acc = new Sum(x.dup());
|
||||
opExecutioner.exec(acc);
|
||||
assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
||||
assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||
Prod prod = new Prod(x.dup());
|
||||
opExecutioner.exec(prod);
|
||||
assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
||||
assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
|
||||
|
@ -275,7 +275,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
Variance variance = new Variance(x.dup(), true);
|
||||
opExecutioner.exec(variance);
|
||||
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
||||
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
|
||||
|
@ -284,14 +284,14 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testIamax(Nd4jBackend backend) {
|
||||
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage());
|
||||
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
public void testIamax2(Nd4jBackend backend) {
|
||||
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage());
|
||||
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage(backend));
|
||||
val op = new ArgAmax(linspace);
|
||||
|
||||
int iamax = Nd4j.getExecutioner().exec(op)[0].getInt(0);
|
||||
|
@ -307,11 +307,11 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
|
||||
Mean mean = new Mean(x);
|
||||
opExecutioner.exec(mean);
|
||||
assertEquals( 3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
||||
assertEquals( 3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||
|
||||
Variance variance = new Variance(x.dup(), true);
|
||||
opExecutioner.exec(variance);
|
||||
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
||||
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -321,7 +321,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
|
||||
val softMax = new SoftMax(arr);
|
||||
opExecutioner.exec((CustomOp) softMax);
|
||||
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage());
|
||||
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
|
||||
|
@ -332,7 +332,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
Pow pow = new Pow(oneThroughSix, 2);
|
||||
Nd4j.getExecutioner().exec(pow);
|
||||
INDArray answer = Nd4j.create(new double[] {1, 4, 9, 16, 25, 36});
|
||||
assertEquals(answer, pow.z(),getFailureMessage());
|
||||
assertEquals(answer, pow.z(),getFailureMessage(backend));
|
||||
}
|
||||
|
||||
|
||||
|
@ -384,7 +384,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
Log log = new Log(slice);
|
||||
opExecutioner.exec(log);
|
||||
INDArray assertion = Nd4j.create(new double[] {0., 1.09861229, 1.60943791});
|
||||
assertEquals(assertion, slice,getFailureMessage());
|
||||
assertEquals(assertion, slice,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -572,7 +572,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
expected[i] = (float) Math.exp(slice.getDouble(i));
|
||||
Exp exp = new Exp(slice);
|
||||
opExecutioner.exec(exp);
|
||||
assertEquals( Nd4j.create(expected), slice,getFailureMessage());
|
||||
assertEquals( Nd4j.create(expected), slice,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -582,7 +582,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
|||
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
|
||||
val softMax = new SoftMax(arr);
|
||||
opExecutioner.exec((CustomOp) softMax);
|
||||
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage());
|
||||
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue