Unify nd4j test profiles, get rid of old modules, fix more parameter issues with junit 5 tests
parent
e0077c38a9
commit
ad4f47096c
|
@ -31,7 +31,7 @@ jobs:
|
||||||
protoc --version
|
protoc --version
|
||||||
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
windows-x86_64:
|
windows-x86_64:
|
||||||
runs-on: windows-2019
|
runs-on: windows-2019
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,5 +60,5 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ jobs:
|
||||||
protoc --version
|
protoc --version
|
||||||
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
windows-x86_64:
|
windows-x86_64:
|
||||||
runs-on: windows-2019
|
runs-on: windows-2019
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,5 +60,5 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
|
|
|
@ -34,5 +34,5 @@ jobs:
|
||||||
cmake --version
|
cmake --version
|
||||||
protoc --version
|
protoc --version
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Ptest-nd4j-native --also-make clean test
|
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Pnd4j-tests-cpu --also-make clean test
|
||||||
|
|
||||||
|
|
|
@ -109,10 +109,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -30,6 +30,8 @@ import org.datavec.api.writable.Writable;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.loader.FileBatch;
|
import org.nd4j.common.loader.FileBatch;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -40,13 +42,16 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
@DisplayName("File Batch Record Reader Test")
|
@DisplayName("File Batch Record Reader Test")
|
||||||
class FileBatchRecordReaderTest extends BaseND4JTest {
|
public class FileBatchRecordReaderTest extends BaseND4JTest {
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@DisplayName("Test Csv")
|
@DisplayName("Test Csv")
|
||||||
void testCsv(@TempDir Path testDir) throws Exception {
|
void testCsv(Nd4jBackend backend) throws Exception {
|
||||||
// This is an unrealistic use case - one line/record per CSV
|
// This is an unrealistic use case - one line/record per CSV
|
||||||
File baseDir = testDir.toFile();
|
File baseDir = testDir.toFile();
|
||||||
List<File> fileList = new ArrayList<>();
|
List<File> fileList = new ArrayList<>();
|
||||||
|
@ -75,9 +80,10 @@ class FileBatchRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@DisplayName("Test Csv Sequence")
|
@DisplayName("Test Csv Sequence")
|
||||||
void testCsvSequence(@TempDir Path testDir) throws Exception {
|
void testCsvSequence(Nd4jBackend backend) throws Exception {
|
||||||
// CSV sequence - 3 lines per file, 10 files
|
// CSV sequence - 3 lines per file, 10 files
|
||||||
File baseDir = testDir.toFile();
|
File baseDir = testDir.toFile();
|
||||||
List<File> fileList = new ArrayList<>();
|
List<File> fileList = new ArrayList<>();
|
||||||
|
|
|
@ -21,7 +21,6 @@ package org.datavec.api.transform.ops;
|
||||||
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.ExpectedException;
|
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
|
@ -60,10 +60,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -119,10 +119,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -59,10 +59,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -57,10 +57,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -65,10 +65,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -61,25 +61,18 @@
|
||||||
<artifactId>nd4j-common</artifactId>
|
<artifactId>nd4j-common</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.datavec</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>datavec-geo</artifactId>
|
<artifactId>python4j-numpy</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.datavec</groupId>
|
|
||||||
<artifactId>datavec-python</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -29,7 +29,6 @@ import org.datavec.api.transform.reduce.Reducer;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.transform.schema.SequenceSchema;
|
import org.datavec.api.transform.schema.SequenceSchema;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.datavec.python.PythonTransform;
|
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
@ -39,7 +38,6 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
|
||||||
import static java.time.Duration.ofMillis;
|
import static java.time.Duration.ofMillis;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTimeout;
|
import static org.junit.jupiter.api.Assertions.assertTimeout;
|
||||||
|
|
||||||
|
@ -166,37 +164,8 @@ class ExecutionTest {
|
||||||
List<List<Writable>> out = outRdd;
|
List<List<Writable>> out = outRdd;
|
||||||
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
|
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
|
||||||
out = new ArrayList<>(out);
|
out = new ArrayList<>(out);
|
||||||
Collections.sort(out, new Comparator<List<Writable>>() {
|
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
|
||||||
|
|
||||||
@Override
|
|
||||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
|
||||||
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
assertEquals(expOut, out);
|
assertEquals(expOut, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
|
|
||||||
@DisplayName("Test Python Execution Ndarray")
|
|
||||||
void testPythonExecutionNdarray() {
|
|
||||||
assertTimeout(ofMillis(60000), () -> {
|
|
||||||
Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build();
|
|
||||||
TransformProcess transformProcess = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build();
|
|
||||||
List<List<Writable>> functions = new ArrayList<>();
|
|
||||||
List<Writable> firstRow = new ArrayList<>();
|
|
||||||
INDArray firstArr = Nd4j.linspace(1, 4, 4);
|
|
||||||
INDArray secondArr = Nd4j.linspace(1, 4, 4);
|
|
||||||
firstRow.add(new NDArrayWritable(firstArr));
|
|
||||||
firstRow.add(new NDArrayWritable(secondArr));
|
|
||||||
functions.add(firstRow);
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
|
|
||||||
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
|
|
||||||
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
|
|
||||||
INDArray expected = Transforms.sin(firstArr);
|
|
||||||
INDArray secondExpected = Transforms.cos(secondArr);
|
|
||||||
assertEquals(expected, firstResult);
|
|
||||||
assertEquals(secondExpected, secondResult);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,155 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.datavec.local.transforms.transform;
|
|
||||||
|
|
||||||
import org.datavec.api.transform.ColumnType;
|
|
||||||
import org.datavec.api.transform.Transform;
|
|
||||||
import org.datavec.api.transform.geo.LocationType;
|
|
||||||
import org.datavec.api.transform.schema.Schema;
|
|
||||||
import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform;
|
|
||||||
import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform;
|
|
||||||
import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform;
|
|
||||||
import org.datavec.api.writable.DoubleWritable;
|
|
||||||
import org.datavec.api.writable.Text;
|
|
||||||
import org.datavec.api.writable.Writable;
|
|
||||||
import org.junit.AfterClass;
|
|
||||||
import org.junit.BeforeClass;
|
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
|
||||||
|
|
||||||
import java.io.*;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author saudet
|
|
||||||
*/
|
|
||||||
public class TestGeoTransforms {
|
|
||||||
|
|
||||||
@BeforeAll
|
|
||||||
public static void beforeClass() throws Exception {
|
|
||||||
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
|
|
||||||
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
|
|
||||||
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath());
|
|
||||||
}
|
|
||||||
|
|
||||||
@AfterClass
|
|
||||||
public static void afterClass(){
|
|
||||||
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, "");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testCoordinatesDistanceTransform() throws Exception {
|
|
||||||
Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
|
|
||||||
transform.setInputSchema(schema);
|
|
||||||
|
|
||||||
Schema out = transform.transform(schema);
|
|
||||||
assertEquals(4, out.numColumns());
|
|
||||||
assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
|
|
||||||
assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
|
|
||||||
out.getColumnTypes());
|
|
||||||
|
|
||||||
assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
|
|
||||||
transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
|
|
||||||
assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
|
|
||||||
new DoubleWritable(Math.sqrt(160))),
|
|
||||||
transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
|
|
||||||
new Text("10|5"))));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testIPAddressToCoordinatesTransform() throws Exception {
|
|
||||||
Schema schema = new Schema.Builder().addColumnString("column").build();
|
|
||||||
|
|
||||||
Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER");
|
|
||||||
transform.setInputSchema(schema);
|
|
||||||
|
|
||||||
Schema out = transform.transform(schema);
|
|
||||||
|
|
||||||
assertEquals(1, out.getColumnMetaData().size());
|
|
||||||
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
|
|
||||||
|
|
||||||
String in = "81.2.69.160";
|
|
||||||
double latitude = 51.5142;
|
|
||||||
double longitude = -0.0931;
|
|
||||||
|
|
||||||
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
|
|
||||||
assertEquals(1, writables.size());
|
|
||||||
String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
|
|
||||||
assertEquals(2, coordinates.length);
|
|
||||||
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
|
|
||||||
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
|
|
||||||
|
|
||||||
//Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/
|
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
|
||||||
ObjectOutputStream oos = new ObjectOutputStream(baos);
|
|
||||||
oos.writeObject(transform);
|
|
||||||
|
|
||||||
byte[] bytes = baos.toByteArray();
|
|
||||||
|
|
||||||
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
|
||||||
ObjectInputStream ois = new ObjectInputStream(bais);
|
|
||||||
|
|
||||||
Transform deserialized = (Transform) ois.readObject();
|
|
||||||
writables = deserialized.map(Collections.singletonList((Writable) new Text(in)));
|
|
||||||
assertEquals(1, writables.size());
|
|
||||||
coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
|
|
||||||
//System.out.println(Arrays.toString(coordinates));
|
|
||||||
assertEquals(2, coordinates.length);
|
|
||||||
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
|
|
||||||
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testIPAddressToLocationTransform() throws Exception {
|
|
||||||
Schema schema = new Schema.Builder().addColumnString("column").build();
|
|
||||||
LocationType[] locationTypes = LocationType.values();
|
|
||||||
String in = "81.2.69.160";
|
|
||||||
String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167",
|
|
||||||
"51.5142:-0.0931", "", "England", "6269131"}; //Note: no postcode in this test DB for this record
|
|
||||||
|
|
||||||
for (int i = 0; i < locationTypes.length; i++) {
|
|
||||||
LocationType locationType = locationTypes[i];
|
|
||||||
String location = locations[i];
|
|
||||||
|
|
||||||
Transform transform = new IPAddressToLocationTransform("column", locationType);
|
|
||||||
transform.setInputSchema(schema);
|
|
||||||
|
|
||||||
Schema out = transform.transform(schema);
|
|
||||||
|
|
||||||
assertEquals(1, out.getColumnMetaData().size());
|
|
||||||
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
|
|
||||||
|
|
||||||
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
|
|
||||||
assertEquals(1, writables.size());
|
|
||||||
assertEquals(location, writables.get(0).toString());
|
|
||||||
//System.out.println(location);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,386 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.datavec.local.transforms.transform;
|
|
||||||
|
|
||||||
import org.datavec.api.transform.TransformProcess;
|
|
||||||
import org.datavec.api.transform.condition.Condition;
|
|
||||||
import org.datavec.api.transform.filter.ConditionFilter;
|
|
||||||
import org.datavec.api.transform.filter.Filter;
|
|
||||||
import org.datavec.api.transform.schema.Schema;
|
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
|
||||||
|
|
||||||
import org.datavec.api.writable.*;
|
|
||||||
import org.datavec.python.PythonCondition;
|
|
||||||
import org.datavec.python.PythonTransform;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.junit.jupiter.api.Timeout;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import javax.annotation.concurrent.NotThreadSafe;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
|
|
||||||
import static org.datavec.api.transform.schema.Schema.Builder;
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
|
||||||
|
|
||||||
@NotThreadSafe
|
|
||||||
public class TestPythonTransformProcess {
|
|
||||||
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
public void testStringConcat() throws Exception{
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnString("col1")
|
|
||||||
.addColumnString("col2");
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnString("col3");
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col3 = col1 + col2";
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build()
|
|
||||||
).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals((outputs.get(0)).toString(), "Hello ");
|
|
||||||
assertEquals((outputs.get(1)).toString(), "World!");
|
|
||||||
assertEquals((outputs.get(2)).toString(), "Hello World!");
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
@Timeout(60000L)
|
|
||||||
public void testMixedTypes() throws Exception {
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnInteger("col1")
|
|
||||||
.addColumnFloat("col2")
|
|
||||||
.addColumnString("col3")
|
|
||||||
.addColumnDouble("col4");
|
|
||||||
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnInteger("col5");
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.inputSchema(initialSchema)
|
|
||||||
.build() ).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList(new IntWritable(10),
|
|
||||||
new FloatWritable(3.5f),
|
|
||||||
new Text("5"),
|
|
||||||
new DoubleWritable(2.0)
|
|
||||||
);
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals(((LongWritable)outputs.get(4)).get(), 36);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
@Timeout(60000L)
|
|
||||||
public void testNDArray() throws Exception {
|
|
||||||
long[] shape = new long[]{3, 2};
|
|
||||||
INDArray arr1 = Nd4j.rand(shape);
|
|
||||||
INDArray arr2 = Nd4j.rand(shape);
|
|
||||||
|
|
||||||
INDArray expectedOutput = arr1.add(arr2);
|
|
||||||
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnNDArray("col1", shape)
|
|
||||||
.addColumnNDArray("col2", shape);
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnNDArray("col3", shape);
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col3 = col1 + col2";
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build() ).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new NDArrayWritable(arr1),
|
|
||||||
new NDArrayWritable(arr2)
|
|
||||||
);
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
|
||||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
|
||||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
@Timeout(60000L)
|
|
||||||
public void testNDArray2() throws Exception {
|
|
||||||
long[] shape = new long[]{3, 2};
|
|
||||||
INDArray arr1 = Nd4j.rand(shape);
|
|
||||||
INDArray arr2 = Nd4j.rand(shape);
|
|
||||||
|
|
||||||
INDArray expectedOutput = arr1.add(arr2);
|
|
||||||
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnNDArray("col1", shape)
|
|
||||||
.addColumnNDArray("col2", shape);
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnNDArray("col3", shape);
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col3 = col1 + col2";
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build() ).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new NDArrayWritable(arr1),
|
|
||||||
new NDArrayWritable(arr2)
|
|
||||||
);
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
|
||||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
|
||||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
@Timeout(60000L)
|
|
||||||
public void testNDArrayMixed() throws Exception{
|
|
||||||
long[] shape = new long[]{3, 2};
|
|
||||||
INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
|
|
||||||
INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
|
|
||||||
INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
|
|
||||||
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnNDArray("col1", shape)
|
|
||||||
.addColumnNDArray("col2", shape);
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnNDArray("col3", shape);
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col3 = col1 + col2";
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build()
|
|
||||||
).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new NDArrayWritable(arr1),
|
|
||||||
new NDArrayWritable(arr2)
|
|
||||||
);
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
|
||||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
|
||||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
@Timeout(60000L)
|
|
||||||
public void testPythonFilter() {
|
|
||||||
Schema schema = new Builder().addColumnInteger("column").build();
|
|
||||||
|
|
||||||
Condition condition = new PythonCondition(
|
|
||||||
"f = lambda: column < 0"
|
|
||||||
);
|
|
||||||
|
|
||||||
condition.setInputSchema(schema);
|
|
||||||
|
|
||||||
Filter filter = new ConditionFilter(condition);
|
|
||||||
|
|
||||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
|
|
||||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
|
|
||||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
|
|
||||||
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
|
|
||||||
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
@Timeout(60000L)
|
|
||||||
public void testPythonFilterAndTransform() throws Exception {
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnInteger("col1")
|
|
||||||
.addColumnFloat("col2")
|
|
||||||
.addColumnString("col3")
|
|
||||||
.addColumnDouble("col4");
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnString("col6");
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
Condition condition = new PythonCondition(
|
|
||||||
"f = lambda: col1 < 0 and col2 > 10.0"
|
|
||||||
);
|
|
||||||
|
|
||||||
condition.setInputSchema(initialSchema);
|
|
||||||
|
|
||||||
Filter filter = new ConditionFilter(condition);
|
|
||||||
|
|
||||||
String pythonCode = "col6 = str(col1 + col2)";
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build()
|
|
||||||
).filter(
|
|
||||||
filter
|
|
||||||
).build();
|
|
||||||
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
|
||||||
inputs.add(
|
|
||||||
Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new IntWritable(5),
|
|
||||||
new FloatWritable(3.0f),
|
|
||||||
new Text("abcd"),
|
|
||||||
new DoubleWritable(2.1))
|
|
||||||
);
|
|
||||||
inputs.add(
|
|
||||||
Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new IntWritable(-3),
|
|
||||||
new FloatWritable(3.0f),
|
|
||||||
new Text("abcd"),
|
|
||||||
new DoubleWritable(2.1))
|
|
||||||
);
|
|
||||||
inputs.add(
|
|
||||||
Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new IntWritable(5),
|
|
||||||
new FloatWritable(11.2f),
|
|
||||||
new Text("abcd"),
|
|
||||||
new DoubleWritable(2.1))
|
|
||||||
);
|
|
||||||
|
|
||||||
LocalTransformExecutor.execute(inputs,tp);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testPythonTransformNoOutputSpecified() throws Exception {
|
|
||||||
PythonTransform pythonTransform = PythonTransform.builder()
|
|
||||||
.code("a += 2; b = 'hello world'")
|
|
||||||
.returnAllInputs(true)
|
|
||||||
.build();
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
|
||||||
inputs.add(Arrays.asList((Writable)new IntWritable(1)));
|
|
||||||
Schema inputSchema = new Builder()
|
|
||||||
.addColumnInteger("a")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
|
||||||
.transform(pythonTransform)
|
|
||||||
.build();
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
|
||||||
assertEquals(3,execute.get(0).get(0).toInt());
|
|
||||||
assertEquals("hello world",execute.get(0).get(1).toString());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testNumpyTransform() {
|
|
||||||
PythonTransform pythonTransform = PythonTransform.builder()
|
|
||||||
.code("a += 2; b = 'hello world'")
|
|
||||||
.returnAllInputs(true)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
|
||||||
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
|
|
||||||
Schema inputSchema = new Builder()
|
|
||||||
.addColumnNDArray("a",new long[]{1,1})
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
|
||||||
.transform(pythonTransform)
|
|
||||||
.build();
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
|
||||||
assertFalse(execute.isEmpty());
|
|
||||||
assertNotNull(execute.get(0));
|
|
||||||
assertNotNull(execute.get(0).get(0));
|
|
||||||
assertNotNull(execute.get(0).get(1));
|
|
||||||
assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get());
|
|
||||||
assertEquals("hello world",execute.get(0).get(1).toString());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testWithSetupRun() throws Exception {
|
|
||||||
|
|
||||||
PythonTransform pythonTransform = PythonTransform.builder()
|
|
||||||
.code("five=None\n" +
|
|
||||||
"def setup():\n" +
|
|
||||||
" global five\n"+
|
|
||||||
" five = 5\n\n" +
|
|
||||||
"def run(a, b):\n" +
|
|
||||||
" c = a + b + five\n"+
|
|
||||||
" return {'c':c}\n\n")
|
|
||||||
.returnAllInputs(true)
|
|
||||||
.setupAndRun(true)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
|
||||||
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)),
|
|
||||||
new NDArrayWritable(Nd4j.scalar(2).reshape(1,1))));
|
|
||||||
Schema inputSchema = new Builder()
|
|
||||||
.addColumnNDArray("a",new long[]{1,1})
|
|
||||||
.addColumnNDArray("b", new long[]{1, 1})
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
|
||||||
.transform(pythonTransform)
|
|
||||||
.build();
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
|
||||||
assertFalse(execute.isEmpty());
|
|
||||||
assertNotNull(execute.get(0));
|
|
||||||
assertNotNull(execute.get(0).get(0));
|
|
||||||
assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -128,10 +128,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -92,6 +92,10 @@
|
||||||
<groupId>org.junit.jupiter</groupId>
|
<groupId>org.junit.jupiter</groupId>
|
||||||
<artifactId>junit-jupiter-api</artifactId>
|
<artifactId>junit-jupiter-api</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-params</artifactId>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.junit.vintage</groupId>
|
<groupId>org.junit.vintage</groupId>
|
||||||
<artifactId>junit-vintage-engine</artifactId>
|
<artifactId>junit-vintage-engine</artifactId>
|
||||||
|
@ -154,7 +158,7 @@
|
||||||
<skip>${skipTestResourceEnforcement}</skip>
|
<skip>${skipTestResourceEnforcement}</skip>
|
||||||
<rules>
|
<rules>
|
||||||
<requireActiveProfile>
|
<requireActiveProfile>
|
||||||
<profiles>test-nd4j-native,test-nd4j-cuda-11.0</profiles>
|
<profiles>nd4j-tests-cpu,nd4j-tests-cuda</profiles>
|
||||||
<all>false</all>
|
<all>false</all>
|
||||||
</requireActiveProfile>
|
</requireActiveProfile>
|
||||||
</rules>
|
</rules>
|
||||||
|
@ -163,23 +167,6 @@
|
||||||
</execution>
|
</execution>
|
||||||
</executions>
|
</executions>
|
||||||
</plugin>
|
</plugin>
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<configuration>
|
|
||||||
<argLine></argLine>
|
|
||||||
|
|
||||||
<!--
|
|
||||||
By default: Surefire will set the classpath based on the manifest. Because tests are not included
|
|
||||||
in the JAR, any tests that rely on class path scanning for resources in the tests directory will not
|
|
||||||
function correctly without this configuration.
|
|
||||||
For example, tests for custom transforms (where the custom transform is defined in the test directory)
|
|
||||||
will fail due to the custom transform not being found on the classpath.
|
|
||||||
http://maven.apache.org/surefire/maven-surefire-plugin/examples/class-loading.html
|
|
||||||
-->
|
|
||||||
<useSystemClassLoader>true</useSystemClassLoader>
|
|
||||||
<useManifestOnlyJar>false</useManifestOnlyJar>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.eclipse.m2e</groupId>
|
<groupId>org.eclipse.m2e</groupId>
|
||||||
<artifactId>lifecycle-mapping</artifactId>
|
<artifactId>lifecycle-mapping</artifactId>
|
||||||
|
@ -249,7 +236,7 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -266,7 +253,7 @@
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -286,9 +273,6 @@
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<configuration>
|
|
||||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
</build>
|
</build>
|
||||||
|
|
|
@ -64,7 +64,7 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -75,7 +75,7 @@
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
|
|
@ -56,10 +56,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -166,7 +166,7 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -177,7 +177,7 @@
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
|
|
@ -23,7 +23,6 @@ import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.ExpectedException;
|
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
|
||||||
|
@ -34,7 +33,6 @@ import java.util.List;
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
|
||||||
|
|
||||||
@DisplayName("Early Termination Data Set Iterator Test")
|
@DisplayName("Early Termination Data Set Iterator Test")
|
||||||
class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
|
class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
|
||||||
|
|
|
@ -21,19 +21,16 @@ package org.deeplearning4j.datasets.iterator;
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.ExpectedException;
|
|
||||||
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import org.junit.jupiter.api.DisplayName;
|
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@DisplayName("Early Termination Multi Data Set Iterator Test")
|
@DisplayName("Early Termination Multi Data Set Iterator Test")
|
||||||
|
|
|
@ -34,7 +34,6 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.ExpectedException;
|
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -46,7 +45,6 @@ import java.util.Random;
|
||||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
|
||||||
|
|
||||||
@Disabled
|
@Disabled
|
||||||
@DisplayName("Attention Layer Test")
|
@DisplayName("Attention Layer Test")
|
||||||
|
|
|
@ -105,11 +105,12 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
<plugin>
|
<plugin>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
|
<inherited>true</inherited>
|
||||||
<configuration>
|
<configuration>
|
||||||
<skip>true</skip>
|
<skip>true</skip>
|
||||||
</configuration>
|
</configuration>
|
||||||
|
@ -118,7 +119,7 @@
|
||||||
</build>
|
</build>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -56,10 +56,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -50,10 +50,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -45,10 +45,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -54,10 +54,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -112,7 +112,7 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -123,7 +123,7 @@
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
|
|
@ -72,10 +72,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -306,7 +306,7 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -317,7 +317,7 @@
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
|
|
@ -101,10 +101,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -49,10 +49,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -127,10 +127,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -102,7 +102,7 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -113,7 +113,7 @@
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
|
|
@ -99,10 +99,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -44,10 +44,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -89,10 +89,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -88,10 +88,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -90,10 +90,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -105,10 +105,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -182,10 +182,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -77,10 +77,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -104,10 +104,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -141,10 +141,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -79,10 +79,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -426,10 +426,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
|
@ -44,10 +44,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -87,10 +87,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -117,10 +117,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
|
@ -143,6 +143,10 @@
|
||||||
</extensions>
|
</extensions>
|
||||||
|
|
||||||
<plugins>
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
|
</plugin>
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-enforcer-plugin</artifactId>
|
<artifactId>maven-enforcer-plugin</artifactId>
|
||||||
|
@ -158,7 +162,7 @@
|
||||||
<skip>${skipBackendChoice}</skip>
|
<skip>${skipBackendChoice}</skip>
|
||||||
<rules>
|
<rules>
|
||||||
<requireActiveProfile>
|
<requireActiveProfile>
|
||||||
<profiles>test-nd4j-native,test-nd4j-cuda-11.0</profiles>
|
<profiles>nd4j-tests-cpu,nd4j-tests-cuda</profiles>
|
||||||
<all>false</all>
|
<all>false</all>
|
||||||
</requireActiveProfile>
|
</requireActiveProfile>
|
||||||
</rules>
|
</rules>
|
||||||
|
@ -227,43 +231,6 @@
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
|
|
||||||
<pluginManagement>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<inherited>true</inherited>
|
|
||||||
<configuration>
|
|
||||||
<!--
|
|
||||||
By default: Surefire will set the classpath based on the manifest. Because tests are not included
|
|
||||||
in the JAR, any tests that rely on class path scanning for resources in the tests directory will not
|
|
||||||
function correctly without this configuration.
|
|
||||||
For example, tests for custom layers (where the custom layer is defined in the test directory)
|
|
||||||
will fail due to the custom layer not being found on the classpath.
|
|
||||||
http://maven.apache.org/surefire/maven-surefire-plugin/examples/class-loading.html
|
|
||||||
-->
|
|
||||||
<useSystemClassLoader>true</useSystemClassLoader>
|
|
||||||
<useManifestOnlyJar>false</useManifestOnlyJar>
|
|
||||||
<argLine> -Dfile.encoding=UTF-8 -Xmx8g "</argLine>
|
|
||||||
<includes>
|
|
||||||
<!-- Default setting only runs tests that start/end with "Test" -->
|
|
||||||
<include>*.java</include>
|
|
||||||
<include>**/*.java</include>
|
|
||||||
</includes>
|
|
||||||
</configuration>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.maven.surefire</groupId>
|
|
||||||
<artifactId>surefire-junit-platform</artifactId>
|
|
||||||
<version>${maven-surefire-plugin.version}</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.eclipse.m2e</groupId>
|
|
||||||
<artifactId>lifecycle-mapping</artifactId>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</pluginManagement>
|
|
||||||
</build>
|
</build>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
@ -290,10 +257,10 @@
|
||||||
<module>deeplearning4j-cuda</module>
|
<module>deeplearning4j-cuda</module>
|
||||||
</modules>
|
</modules>
|
||||||
</profile>
|
</profile>
|
||||||
<!-- For running unit tests with nd4j-native: "mvn clean test -P test-nd4j-native"
|
<!-- For running unit tests with nd4j-native: "mvn clean test -P nd4j-tests-cpu"
|
||||||
Note that this excludes DL4J-cuda -->
|
Note that this excludes DL4J-cuda -->
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<activation>
|
<activation>
|
||||||
<activeByDefault>false</activeByDefault>
|
<activeByDefault>false</activeByDefault>
|
||||||
</activation>
|
</activation>
|
||||||
|
@ -311,70 +278,10 @@
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<inherited>true</inherited>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-native</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.junit.jupiter</groupId>
|
|
||||||
<artifactId>junit-jupiter-engine</artifactId>
|
|
||||||
<version>${junit.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.junit.jupiter</groupId>
|
|
||||||
<artifactId>junit-jupiter-params</artifactId>
|
|
||||||
<version>${junit.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.maven.surefire</groupId>
|
|
||||||
<artifactId>surefire-junit-platform</artifactId>
|
|
||||||
<version>${maven-surefire-plugin.version}</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
<configuration>
|
|
||||||
<environmentVariables>
|
|
||||||
|
|
||||||
</environmentVariables>
|
|
||||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>*.java</include>
|
|
||||||
<include>**/*.java</include>
|
|
||||||
<include>**/Test*.java</include>
|
|
||||||
<include>**/*Test.java</include>
|
|
||||||
<include>**/*TestCase.java</include>
|
|
||||||
</includes>
|
|
||||||
<junitArtifactName>org.junit.jupiter:junit-jupiter-engine</junitArtifactName>
|
|
||||||
<systemPropertyVariables>
|
|
||||||
<org.nd4j.linalg.defaultbackend>
|
|
||||||
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
|
||||||
</org.nd4j.linalg.defaultbackend>
|
|
||||||
<org.nd4j.linalg.tests.backendstorun>
|
|
||||||
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
|
||||||
</org.nd4j.linalg.tests.backendstorun>
|
|
||||||
</systemPropertyVariables>
|
|
||||||
<!--
|
|
||||||
Maximum heap size was set to 8g, as a minimum required value for tests run.
|
|
||||||
Depending on a build machine, default value is not always enough.
|
|
||||||
|
|
||||||
For testing large zoo models, this may not be enough (so comment it out).
|
|
||||||
-->
|
|
||||||
<argLine></argLine>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
</profile>
|
</profile>
|
||||||
<!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" -->
|
<!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" -->
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<activation>
|
<activation>
|
||||||
<activeByDefault>false</activeByDefault>
|
<activeByDefault>false</activeByDefault>
|
||||||
</activation>
|
</activation>
|
||||||
|
@ -392,43 +299,6 @@
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
<!-- Default to ALL modules here, unlike nd4j-native -->
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<version>${maven-surefire-plugin.version}</version>
|
|
||||||
<configuration>
|
|
||||||
<environmentVariables>
|
|
||||||
</environmentVariables>
|
|
||||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>*.java</include>
|
|
||||||
<include>**/*.java</include>
|
|
||||||
<include>**/Test*.java</include>
|
|
||||||
<include>**/*Test.java</include>
|
|
||||||
<include>**/*TestCase.java</include>
|
|
||||||
</includes>
|
|
||||||
<junitArtifactName>org.junit.jupiter:junit-jupiter</junitArtifactName>
|
|
||||||
<systemPropertyVariables>
|
|
||||||
<org.nd4j.linalg.defaultbackend>
|
|
||||||
org.nd4j.linalg.jcublas.JCublasBackend
|
|
||||||
</org.nd4j.linalg.defaultbackend>
|
|
||||||
<org.nd4j.linalg.tests.backendstorun>
|
|
||||||
org.nd4j.linalg.jcublas.JCublasBackend
|
|
||||||
</org.nd4j.linalg.tests.backendstorun>
|
|
||||||
</systemPropertyVariables>
|
|
||||||
<!--
|
|
||||||
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
|
||||||
Depending on a build machine, default value is not always enough.
|
|
||||||
-->
|
|
||||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
|
||||||
</configuration>
|
|
||||||
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -5,7 +5,7 @@ Linux
|
||||||
[INFO] Total time: 14.610 s
|
[INFO] Total time: 14.610 s
|
||||||
[INFO] Finished at: 2021-03-06T15:35:28+09:00
|
[INFO] Finished at: 2021-03-06T15:35:28+09:00
|
||||||
[INFO] ------------------------------------------------------------------------
|
[INFO] ------------------------------------------------------------------------
|
||||||
[WARNING] The requested profile "test-nd4j-native" could not be activated because it does not exist.
|
[WARNING] The requested profile "nd4j-tests-cpu" could not be activated because it does not exist.
|
||||||
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
|
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
|
||||||
[ERROR]
|
[ERROR]
|
||||||
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
|
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
|
||||||
|
@ -749,7 +749,7 @@ make[1]: Leaving directory '/c/Users/agibs/Documents/GitHub/eclipse-deeplearning
|
||||||
[INFO] Total time: 15.482 s
|
[INFO] Total time: 15.482 s
|
||||||
[INFO] Finished at: 2021-03-06T15:27:35+09:00
|
[INFO] Finished at: 2021-03-06T15:27:35+09:00
|
||||||
[INFO] ------------------------------------------------------------------------
|
[INFO] ------------------------------------------------------------------------
|
||||||
[WARNING] The requested profile "test-nd4j-native" could not be activated because it does not exist.
|
[WARNING] The requested profile "nd4j-tests-cpu" could not be activated because it does not exist.
|
||||||
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
|
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
|
||||||
[ERROR]
|
[ERROR]
|
||||||
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
|
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
Const,in_0
|
||||||
|
Const,while/Const
|
||||||
|
Const,while/add/y
|
||||||
|
Identity,in_0/read
|
||||||
|
Enter,while/Enter
|
||||||
|
Enter,while/Enter_1
|
||||||
|
Merge,while/Merge
|
||||||
|
Merge,while/Merge_1
|
||||||
|
Less,while/Less
|
||||||
|
LoopCond,while/LoopCond
|
||||||
|
Switch,while/Switch
|
||||||
|
Switch,while/Switch_1
|
||||||
|
Identity,while/Identity
|
||||||
|
Exit,while/Exit
|
||||||
|
Identity,while/Identity_1
|
||||||
|
Exit,while/Exit_1
|
||||||
|
Add,while/add
|
||||||
|
NextIteration,while/NextIteration_1
|
||||||
|
NextIteration,while/NextIteration
|
|
@ -0,0 +1,16 @@
|
||||||
|
Identity,in_0/read
|
||||||
|
Enter,while/Enter
|
||||||
|
Enter,while/Enter_1
|
||||||
|
Merge,while/Merge
|
||||||
|
Merge,while/Merge_1
|
||||||
|
Less,while/Less
|
||||||
|
LoopCond,while/LoopCond
|
||||||
|
Switch,while/Switch
|
||||||
|
Switch,while/Switch_1
|
||||||
|
Identity,while/Identity
|
||||||
|
Exit,while/Exit
|
||||||
|
Identity,while/Identity_1
|
||||||
|
Exit,while/Exit_1
|
||||||
|
Add,while/add
|
||||||
|
NextIteration,while/NextIteration_1
|
||||||
|
NextIteration,while/NextIteration
|
|
@ -0,0 +1,19 @@
|
||||||
|
in_0
|
||||||
|
while/Const
|
||||||
|
while/add/y
|
||||||
|
in_0/read
|
||||||
|
while/Enter
|
||||||
|
while/Enter_1
|
||||||
|
while/Merge
|
||||||
|
while/Merge_1
|
||||||
|
while/Less
|
||||||
|
while/LoopCond
|
||||||
|
while/Switch
|
||||||
|
while/Switch_1
|
||||||
|
while/Identity
|
||||||
|
while/Exit
|
||||||
|
while/Identity_1
|
||||||
|
while/Exit_1
|
||||||
|
while/add
|
||||||
|
while/NextIteration_1
|
||||||
|
while/NextIteration
|
|
@ -303,7 +303,7 @@
|
||||||
|
|
||||||
For testing large zoo models, this may not be enough (so comment it out).
|
For testing large zoo models, this may not be enough (so comment it out).
|
||||||
-->
|
-->
|
||||||
<argLine>-Dfile.encoding=UTF-8 "</argLine>
|
<argLine>-Dfile.encoding=UTF-8 </argLine>
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
|
|
|
@ -27,6 +27,7 @@ import java.util.List;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.TestInfo;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -482,7 +483,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConv3d(Nd4jBackend backend) {
|
public void testConv3d(Nd4jBackend backend, TestInfo testInfo) {
|
||||||
//Pooling3d, Conv3D, batch norm
|
//Pooling3d, Conv3D, batch norm
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -573,7 +574,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
tc.testName(msg);
|
tc.testName(msg);
|
||||||
String error = OpValidation.validate(tc);
|
String error = OpValidation.validate(tc);
|
||||||
if (error != null) {
|
if (error != null) {
|
||||||
failed.add(name);
|
failed.add(testInfo.getTestMethod().get().getName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1353,7 +1354,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err, err);
|
assertNull(err, err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) {
|
public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
|
@ -1382,7 +1384,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) {
|
public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -1405,7 +1408,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) {
|
public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -22,6 +22,7 @@ package org.nd4j.autodiff.opvalidation;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -664,6 +665,7 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
public void testMmulGradientManual(Nd4jBackend backend) {
|
public void testMmulGradientManual(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
|
|
|
@ -69,7 +69,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
public void tearDown(Nd4jBackend backend) {
|
public void tearDown() {
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
|
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,7 @@ import lombok.val;
|
||||||
import org.apache.commons.math3.linear.LUDecomposition;
|
import org.apache.commons.math3.linear.LUDecomposition;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.TestInfo;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.OpValidationSuite;
|
import org.nd4j.OpValidationSuite;
|
||||||
|
@ -83,7 +84,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConcat(Nd4jBackend backend) {
|
public void testConcat(Nd4jBackend backend, TestInfo testInfo) {
|
||||||
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
|
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
|
||||||
int[] concatDim = new int[]{0, 0, 0};
|
int[] concatDim = new int[]{0, 0, 0};
|
||||||
List<List<int[]>> origShapes = new ArrayList<>();
|
List<List<int[]>> origShapes = new ArrayList<>();
|
||||||
|
@ -115,7 +116,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
String error = OpValidation.validate(tc);
|
String error = OpValidation.validate(tc);
|
||||||
if(error != null){
|
if(error != null){
|
||||||
failed.add(name);
|
failed.add(testInfo.getTestMethod().get().getName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -285,7 +286,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSqueezeGradient(Nd4jBackend backend) {
|
public void testSqueezeGradient(Nd4jBackend backend,TestInfo testInfo) {
|
||||||
val origShape = new long[]{3, 4, 5};
|
val origShape = new long[]{3, 4, 5};
|
||||||
|
|
||||||
List<String> failed = new ArrayList<>();
|
List<String> failed = new ArrayList<>();
|
||||||
|
@ -339,7 +340,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
String error = OpValidation.validate(tc, true);
|
String error = OpValidation.validate(tc, true);
|
||||||
if(error != null){
|
if(error != null){
|
||||||
failed.add(name);
|
failed.add(testInfo.getTestMethod().get().getName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -580,8 +581,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
return Long.MAX_VALUE;
|
return Long.MAX_VALUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
public void testStack(Nd4jBackend backend) {
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
public void testStack(Nd4jBackend backend,TestInfo testInfo) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
List<String> failed = new ArrayList<>();
|
List<String> failed = new ArrayList<>();
|
||||||
|
@ -661,7 +663,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
String error = OpValidation.validate(tc);
|
String error = OpValidation.validate(tc);
|
||||||
if(error != null){
|
if(error != null){
|
||||||
failed.add(name);
|
failed.add(testInfo.getTestMethod().get().getName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,6 +72,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering(){
|
public char ordering(){
|
||||||
|
@ -82,7 +84,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testBasic(Nd4jBackend backend) throws Exception {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
|
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
|
||||||
SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() );
|
SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() );
|
||||||
|
@ -121,7 +123,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
int numOutputs = fg.outputsLength();
|
int numOutputs = fg.outputsLength();
|
||||||
List<IntPair> outputs = new ArrayList<>(numOutputs);
|
List<IntPair> outputs = new ArrayList<>(numOutputs);
|
||||||
for( int i=0; i<numOutputs; i++ ){
|
for( int i = 0; i < numOutputs; i++) {
|
||||||
outputs.add(fg.outputs(i));
|
outputs.add(fg.outputs(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,7 +140,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testSimple(Nd4jBackend backend) throws Exception {
|
||||||
for( int i = 0; i < 10; i++ ) {
|
for( int i = 0; i < 10; i++ ) {
|
||||||
for(boolean execFirst : new boolean[]{false, true}) {
|
for(boolean execFirst : new boolean[]{false, true}) {
|
||||||
log.info("Starting test: i={}, execFirst={}", i, execFirst);
|
log.info("Starting test: i={}, execFirst={}", i, execFirst);
|
||||||
|
@ -268,7 +270,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testTrainingSerde(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
//Ensure 2 things:
|
//Ensure 2 things:
|
||||||
//1. Training config is serialized/deserialized correctly
|
//1. Training config is serialized/deserialized correctly
|
||||||
|
|
|
@ -109,7 +109,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void before(Nd4jBackend backend) {
|
public void before() {
|
||||||
Nd4j.create(1);
|
Nd4j.create(1);
|
||||||
initialType = Nd4j.dataType();
|
initialType = Nd4j.dataType();
|
||||||
|
|
||||||
|
@ -118,7 +118,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
public void after(Nd4jBackend backend) {
|
public void after() {
|
||||||
Nd4j.setDataType(initialType);
|
Nd4j.setDataType(initialType);
|
||||||
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
||||||
|
|
|
@ -21,9 +21,6 @@
|
||||||
package org.nd4j.autodiff.samediff.listeners;
|
package org.nd4j.autodiff.samediff.listeners;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -47,11 +44,11 @@ import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -97,7 +94,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testCheckpointEveryEpoch(Nd4jBackend backend) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
|
|
||||||
SameDiff sd = getModel();
|
SameDiff sd = getModel();
|
||||||
|
@ -132,7 +129,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testCheckpointEvery5Iter(Nd4jBackend backend) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
|
|
||||||
SameDiff sd = getModel();
|
SameDiff sd = getModel();
|
||||||
|
@ -172,7 +169,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testCheckpointListenerEveryTimeUnit(Nd4jBackend backend) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
SameDiff sd = getModel();
|
SameDiff sd = getModel();
|
||||||
|
|
||||||
|
@ -217,7 +214,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testCheckpointListenerKeepLast3AndEvery3(Nd4jBackend backend) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
SameDiff sd = getModel();
|
SameDiff sd = getModel();
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ package org.nd4j.autodiff.samediff.listeners;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
@ -49,6 +50,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
|
||||||
public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
|
public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
|
@ -59,7 +61,8 @@ public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
@Disabled
|
||||||
|
public void testProfilingListenerSimple(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
|
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
|
||||||
|
|
|
@ -64,6 +64,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering(){
|
public char ordering(){
|
||||||
|
@ -81,7 +82,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
|
public void testSimple(Nd4jBackend backend) throws IOException {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
|
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
|
||||||
SDVariable sum = v.sum();
|
SDVariable sum = v.sum();
|
||||||
|
@ -163,7 +164,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
//Append a number of events
|
//Append a number of events
|
||||||
w.registerEventName("accuracy");
|
w.registerEventName("accuracy");
|
||||||
for( int iter=0; iter<3; iter++) {
|
for( int iter = 0; iter < 3; iter++) {
|
||||||
long t = System.currentTimeMillis();
|
long t = System.currentTimeMillis();
|
||||||
w.writeScalarEvent("accuracy", LogFileWriter.EventSubtype.EVALUATION, t, iter, 0, 0.5 + 0.1 * iter);
|
w.writeScalarEvent("accuracy", LogFileWriter.EventSubtype.EVALUATION, t, iter, 0, 0.5 + 0.1 * iter);
|
||||||
}
|
}
|
||||||
|
@ -175,7 +176,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
||||||
UIAddName addName = (UIAddName) events.get(0).getRight();
|
UIAddName addName = (UIAddName) events.get(0).getRight();
|
||||||
assertEquals("accuracy", addName.name());
|
assertEquals("accuracy", addName.name());
|
||||||
|
|
||||||
for( int i=1; i<4; i++ ){
|
for( int i = 1; i < 4; i++ ){
|
||||||
FlatArray fa = (FlatArray) events.get(i).getRight();
|
FlatArray fa = (FlatArray) events.get(i).getRight();
|
||||||
INDArray arr = Nd4j.createFromFlatArray(fa);
|
INDArray arr = Nd4j.createFromFlatArray(fa);
|
||||||
|
|
||||||
|
@ -186,7 +187,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{
|
public void testNullBinLabels(Nd4jBackend backend) throws Exception{
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
File f = new File(dir, "temp.bin");
|
File f = new File(dir, "temp.bin");
|
||||||
LogFileWriter w = new LogFileWriter(f);
|
LogFileWriter w = new LogFileWriter(f);
|
||||||
|
|
|
@ -56,6 +56,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
public class UIListenerTest extends BaseNd4jTestWithBackends {
|
public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
return 'c';
|
return 'c';
|
||||||
|
@ -65,7 +67,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testUIListenerBasic(Nd4jBackend backend) throws Exception {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
|
@ -102,7 +104,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testUIListenerContinue(Nd4jBackend backend) throws Exception {
|
||||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
|
|
||||||
SameDiff sd1 = getSimpleNet();
|
SameDiff sd1 = getSimpleNet();
|
||||||
|
@ -194,7 +196,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testUIListenerBadContinue(Nd4jBackend backend) throws Exception {
|
||||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
SameDiff sd1 = getSimpleNet();
|
SameDiff sd1 = getSimpleNet();
|
||||||
|
|
||||||
|
@ -275,7 +277,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private static SameDiff getSimpleNet(){
|
private static SameDiff getSimpleNet() {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
|
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
|
||||||
|
|
|
@ -22,6 +22,7 @@ package org.nd4j.evaluation;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -48,6 +49,7 @@ public class NewInstanceTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
public void testNewInstances(Nd4jBackend backend) {
|
public void testNewInstances(Nd4jBackend backend) {
|
||||||
boolean print = true;
|
boolean print = true;
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
|
|
||||||
package org.nd4j.evaluation;
|
package org.nd4j.evaluation;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.evaluation.classification.ROC;
|
import org.nd4j.evaluation.classification.ROC;
|
||||||
|
@ -48,8 +48,9 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
public void testROCBinary(Nd4jBackend backend) {
|
public void testROCBinary(Nd4jBackend backend) {
|
||||||
//Compare ROCBinary to ROC class
|
//Compare ROCBinary to ROC class
|
||||||
|
|
||||||
|
@ -144,7 +145,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRocBinaryMerging(Nd4jBackend backend) {
|
public void testRocBinaryMerging(Nd4jBackend backend) {
|
||||||
for (int nSteps : new int[]{30, 0}) { //0 == exact
|
for (int nSteps : new int[]{30, 0}) { //0 == exact
|
||||||
|
@ -175,7 +176,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testROCBinaryPerOutputMasking(Nd4jBackend backend) {
|
public void testROCBinaryPerOutputMasking(Nd4jBackend backend) {
|
||||||
|
|
||||||
|
@ -216,7 +217,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testROCBinary3d(Nd4jBackend backend) {
|
public void testROCBinary3d(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
||||||
|
@ -251,7 +252,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testROCBinary4d(Nd4jBackend backend) {
|
public void testROCBinary4d(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
|
@ -286,7 +287,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testROCBinary3dMasking(Nd4jBackend backend) {
|
public void testROCBinary3dMasking(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
||||||
|
@ -348,7 +349,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testROCBinary4dMasking(Nd4jBackend backend) {
|
public void testROCBinary4dMasking(Nd4jBackend backend) {
|
||||||
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.nd4j.evaluation;
|
package org.nd4j.evaluation;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -82,16 +83,16 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
expFPR.put(10 / 10.0, 0.0 / totalNegatives);
|
expFPR.put(10 / 10.0, 0.0 / totalNegatives);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRocBasic(Nd4jBackend backend) {
|
public void testRocBasic(Nd4jBackend backend) {
|
||||||
//2 outputs here - probability distribution over classes (softmax)
|
//2 outputs here - probability distribution over classes (softmax)
|
||||||
INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
INDArray predictions = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
||||||
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
||||||
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
|
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
|
||||||
|
|
||||||
INDArray actual = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
|
INDArray actual = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
|
||||||
{0, 1}, {0, 1}});
|
{0, 1}, {0, 1}});
|
||||||
|
|
||||||
ROC roc = new ROC(10);
|
ROC roc = new ROC(10);
|
||||||
roc.eval(actual, predictions);
|
roc.eval(actual, predictions);
|
||||||
|
@ -126,15 +127,15 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(1.0, auc, 1e-6);
|
assertEquals(1.0, auc, 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRocBasicSingleClass(Nd4jBackend backend) {
|
public void testRocBasicSingleClass(Nd4jBackend backend) {
|
||||||
//1 output here - single probability value (sigmoid)
|
//1 output here - single probability value (sigmoid)
|
||||||
|
|
||||||
//add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
//add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
||||||
INDArray predictions =
|
INDArray predictions =
|
||||||
Nd4j.create(new double[] {0.001, 0.101, 0.201, 0.301, 0.401, 0.501, 0.601, 0.701, 0.801, 0.901},
|
Nd4j.create(new double[] {0.001, 0.101, 0.201, 0.301, 0.401, 0.501, 0.601, 0.701, 0.801, 0.901},
|
||||||
new int[] {10, 1});
|
new int[] {10, 1});
|
||||||
|
|
||||||
INDArray actual = Nd4j.create(new double[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}, new int[] {10, 1});
|
INDArray actual = Nd4j.create(new double[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}, new int[] {10, 1});
|
||||||
|
|
||||||
|
@ -165,7 +166,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRoc(Nd4jBackend backend) {
|
public void testRoc(Nd4jBackend backend) {
|
||||||
//Previous tests allowed for a perfect classifier with right threshold...
|
//Previous tests allowed for a perfect classifier with right threshold...
|
||||||
|
@ -173,7 +174,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}});
|
INDArray labels = Nd4j.create(new double[][] {{0, 1}, {0, 1}, {1, 0}, {1, 0}, {1, 0}});
|
||||||
|
|
||||||
INDArray prediction = Nd4j.create(new double[][] {{0.199, 0.801}, {0.499, 0.501}, {0.399, 0.601},
|
INDArray prediction = Nd4j.create(new double[][] {{0.199, 0.801}, {0.499, 0.501}, {0.399, 0.601},
|
||||||
{0.799, 0.201}, {0.899, 0.101}});
|
{0.799, 0.201}, {0.899, 0.101}});
|
||||||
|
|
||||||
Map<Double, Double> expTPR = new HashMap<>();
|
Map<Double, Double> expTPR = new HashMap<>();
|
||||||
double totalPositives = 2.0;
|
double totalPositives = 2.0;
|
||||||
|
@ -251,27 +252,27 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
|
public void testRocTimeSeriesNoMasking(Nd4jBackend backend) {
|
||||||
//Same as first test...
|
//Same as first test...
|
||||||
|
|
||||||
//2 outputs here - probability distribution over classes (softmax)
|
//2 outputs here - probability distribution over classes (softmax)
|
||||||
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
||||||
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
||||||
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
|
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
|
||||||
|
|
||||||
INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
|
INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
|
||||||
{0, 1}, {0, 1}});
|
{0, 1}, {0, 1}});
|
||||||
|
|
||||||
INDArray predictions3d = Nd4j.create(2, 2, 5);
|
INDArray predictions3d = Nd4j.create(2, 2, 5);
|
||||||
INDArray firstTSp =
|
INDArray firstTSp =
|
||||||
predictions3d.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
|
predictions3d.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
|
||||||
assertArrayEquals(new long[] {5, 2}, firstTSp.shape());
|
assertArrayEquals(new long[] {5, 2}, firstTSp.shape());
|
||||||
firstTSp.assign(predictions2d.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all()));
|
firstTSp.assign(predictions2d.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all()));
|
||||||
|
|
||||||
INDArray secondTSp =
|
INDArray secondTSp =
|
||||||
predictions3d.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
|
predictions3d.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()).transpose();
|
||||||
assertArrayEquals(new long[] {5, 2}, secondTSp.shape());
|
assertArrayEquals(new long[] {5, 2}, secondTSp.shape());
|
||||||
secondTSp.assign(predictions2d.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all()));
|
secondTSp.assign(predictions2d.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all()));
|
||||||
|
|
||||||
|
@ -299,23 +300,23 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRocTimeSeriesMasking(Nd4jBackend backend) {
|
public void testRocTimeSeriesMasking(Nd4jBackend backend) {
|
||||||
//2 outputs here - probability distribution over classes (softmax)
|
//2 outputs here - probability distribution over classes (softmax)
|
||||||
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
INDArray predictions2d = Nd4j.create(new double[][] {{1.0, 0.001}, //add 0.001 to avoid numerical/rounding issues (float vs. double, etc)
|
||||||
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
{0.899, 0.101}, {0.799, 0.201}, {0.699, 0.301}, {0.599, 0.401}, {0.499, 0.501}, {0.399, 0.601},
|
||||||
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
|
{0.299, 0.701}, {0.199, 0.801}, {0.099, 0.901}});
|
||||||
|
|
||||||
INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
|
INDArray actual2d = Nd4j.create(new double[][] {{1, 0}, {1, 0}, {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1},
|
||||||
{0, 1}, {0, 1}});
|
{0, 1}, {0, 1}});
|
||||||
|
|
||||||
|
|
||||||
//Create time series data... first time series: length 4. Second time series: length 6
|
//Create time series data... first time series: length 4. Second time series: length 6
|
||||||
INDArray predictions3d = Nd4j.create(2, 2, 6);
|
INDArray predictions3d = Nd4j.create(2, 2, 6);
|
||||||
INDArray tad = predictions3d.tensorAlongDimension(0, 1, 2).transpose();
|
INDArray tad = predictions3d.tensorAlongDimension(0, 1, 2).transpose();
|
||||||
tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
|
tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
|
||||||
.assign(predictions2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
|
.assign(predictions2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
|
||||||
|
|
||||||
tad = predictions3d.tensorAlongDimension(1, 1, 2).transpose();
|
tad = predictions3d.tensorAlongDimension(1, 1, 2).transpose();
|
||||||
tad.assign(predictions2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));
|
tad.assign(predictions2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));
|
||||||
|
@ -324,7 +325,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
INDArray labels3d = Nd4j.create(2, 2, 6);
|
INDArray labels3d = Nd4j.create(2, 2, 6);
|
||||||
tad = labels3d.tensorAlongDimension(0, 1, 2).transpose();
|
tad = labels3d.tensorAlongDimension(0, 1, 2).transpose();
|
||||||
tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
|
tad.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all())
|
||||||
.assign(actual2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
|
.assign(actual2d.get(NDArrayIndex.interval(0, 4), NDArrayIndex.all()));
|
||||||
|
|
||||||
tad = labels3d.tensorAlongDimension(1, 1, 2).transpose();
|
tad = labels3d.tensorAlongDimension(1, 1, 2).transpose();
|
||||||
tad.assign(actual2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));
|
tad.assign(actual2d.get(NDArrayIndex.interval(4, 10), NDArrayIndex.all()));
|
||||||
|
@ -350,7 +351,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
|
public void testCompareRocAndRocMultiClass(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -381,7 +382,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCompare2Vs3Classes(Nd4jBackend backend) {
|
public void testCompare2Vs3Classes(Nd4jBackend backend) {
|
||||||
|
|
||||||
|
@ -431,7 +432,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testROCMerging(Nd4jBackend backend) {
|
public void testROCMerging(Nd4jBackend backend) {
|
||||||
int nArrays = 10;
|
int nArrays = 10;
|
||||||
|
@ -477,7 +478,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testROCMerging2(Nd4jBackend backend) {
|
public void testROCMerging2(Nd4jBackend backend) {
|
||||||
int nArrays = 10;
|
int nArrays = 10;
|
||||||
|
@ -523,7 +524,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testROCMultiMerging(Nd4jBackend backend) {
|
public void testROCMultiMerging(Nd4jBackend backend) {
|
||||||
|
|
||||||
|
@ -572,7 +573,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAUCPrecisionRecall(Nd4jBackend backend) {
|
public void testAUCPrecisionRecall(Nd4jBackend backend) {
|
||||||
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
|
//Assume 2 positive examples, at 0.33 and 0.66 predicted, 1 negative example at 0.25 prob
|
||||||
|
@ -620,7 +621,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRocAucExact(Nd4jBackend backend) {
|
public void testRocAucExact(Nd4jBackend backend) {
|
||||||
|
|
||||||
|
@ -681,20 +682,20 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
double[] p = new double[] {0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503, 0.5955447, 0.96451452,
|
double[] p = new double[] {0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503, 0.5955447, 0.96451452,
|
||||||
0.6531771, 0.74890664, 0.65356987, 0.74771481, 0.96130674, 0.0083883, 0.10644438, 0.29870371,
|
0.6531771, 0.74890664, 0.65356987, 0.74771481, 0.96130674, 0.0083883, 0.10644438, 0.29870371,
|
||||||
0.65641118, 0.80981255, 0.87217591, 0.9646476, 0.72368535, 0.64247533, 0.71745362, 0.46759901,
|
0.65641118, 0.80981255, 0.87217591, 0.9646476, 0.72368535, 0.64247533, 0.71745362, 0.46759901,
|
||||||
0.32558468, 0.43964461, 0.72968908, 0.99401459, 0.67687371, 0.79082252, 0.17091426};
|
0.32558468, 0.43964461, 0.72968908, 0.99401459, 0.67687371, 0.79082252, 0.17091426};
|
||||||
|
|
||||||
double[] l = new double[] {1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
|
double[] l = new double[] {1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
|
||||||
0, 1};
|
0, 1};
|
||||||
|
|
||||||
double[] fpr_skl = new double[] {0.0, 0.0, 0.15789474, 0.15789474, 0.31578947, 0.31578947, 0.52631579,
|
double[] fpr_skl = new double[] {0.0, 0.0, 0.15789474, 0.15789474, 0.31578947, 0.31578947, 0.52631579,
|
||||||
0.52631579, 0.68421053, 0.68421053, 0.84210526, 0.84210526, 0.89473684, 0.89473684, 1.0};
|
0.52631579, 0.68421053, 0.68421053, 0.84210526, 0.84210526, 0.89473684, 0.89473684, 1.0};
|
||||||
double[] tpr_skl = new double[] {0.0, 0.09090909, 0.09090909, 0.18181818, 0.18181818, 0.36363636, 0.36363636,
|
double[] tpr_skl = new double[] {0.0, 0.09090909, 0.09090909, 0.18181818, 0.18181818, 0.36363636, 0.36363636,
|
||||||
0.45454545, 0.45454545, 0.72727273, 0.72727273, 0.90909091, 0.90909091, 1.0, 1.0};
|
0.45454545, 0.45454545, 0.72727273, 0.72727273, 0.90909091, 0.90909091, 1.0, 1.0};
|
||||||
//Note the change to the last value: same TPR and FPR at 0.0083883 and 0.0 -> we add the 0.0 threshold edge case + combine with the previous one. Same result
|
//Note the change to the last value: same TPR and FPR at 0.0083883 and 0.0 -> we add the 0.0 threshold edge case + combine with the previous one. Same result
|
||||||
double[] thr_skl = new double[] {1.0, 0.99401459, 0.96130674, 0.92961609, 0.79082252, 0.74771481, 0.67687371,
|
double[] thr_skl = new double[] {1.0, 0.99401459, 0.96130674, 0.92961609, 0.79082252, 0.74771481, 0.67687371,
|
||||||
0.65641118, 0.64247533, 0.46759901, 0.31637555, 0.20456028, 0.18391881, 0.17091426, 0.0};
|
0.65641118, 0.64247533, 0.46759901, 0.31637555, 0.20456028, 0.18391881, 0.17091426, 0.0};
|
||||||
|
|
||||||
INDArray prob = Nd4j.create(p, new int[] {30, 1});
|
INDArray prob = Nd4j.create(p, new int[] {30, 1});
|
||||||
INDArray label = Nd4j.create(l, new int[] {30, 1});
|
INDArray label = Nd4j.create(l, new int[] {30, 1});
|
||||||
|
@ -784,7 +785,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
|
public void rocExactEdgeCaseReallocation(Nd4jBackend backend) {
|
||||||
|
|
||||||
|
@ -797,7 +798,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
|
public void testPrecisionRecallCurveGetPointMethods(Nd4jBackend backend) {
|
||||||
double[] threshold = new double[101];
|
double[] threshold = new double[101];
|
||||||
|
@ -814,15 +815,15 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
PrecisionRecallCurve prc = new PrecisionRecallCurve(threshold, precision, recall, null, null, null, -1);
|
PrecisionRecallCurve prc = new PrecisionRecallCurve(threshold, precision, recall, null, null, null, -1);
|
||||||
|
|
||||||
PrecisionRecallCurve.Point[] points = new PrecisionRecallCurve.Point[] {
|
PrecisionRecallCurve.Point[] points = new PrecisionRecallCurve.Point[] {
|
||||||
//Test exact:
|
//Test exact:
|
||||||
prc.getPointAtThreshold(0.05), prc.getPointAtPrecision(0.05), prc.getPointAtRecall(1 - 0.05),
|
prc.getPointAtThreshold(0.05), prc.getPointAtPrecision(0.05), prc.getPointAtRecall(1 - 0.05),
|
||||||
|
|
||||||
//Test approximate (point doesn't exist exactly). When it doesn't exist:
|
//Test approximate (point doesn't exist exactly). When it doesn't exist:
|
||||||
//Threshold: lowest threshold equal to or exceeding the specified threshold value
|
//Threshold: lowest threshold equal to or exceeding the specified threshold value
|
||||||
//Precision: lowest threshold equal to or exceeding the specified precision value
|
//Precision: lowest threshold equal to or exceeding the specified precision value
|
||||||
//Recall: highest threshold equal to or exceeding the specified recall value
|
//Recall: highest threshold equal to or exceeding the specified recall value
|
||||||
prc.getPointAtThreshold(0.0495), prc.getPointAtPrecision(0.0495),
|
prc.getPointAtThreshold(0.0495), prc.getPointAtPrecision(0.0495),
|
||||||
prc.getPointAtRecall(1 - 0.0505)};
|
prc.getPointAtRecall(1 - 0.0505)};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -834,7 +835,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
|
public void testPrecisionRecallCurveConfusion(Nd4jBackend backend) {
|
||||||
//Sanity check: values calculated from the confusion matrix should match the PR curve values
|
//Sanity check: values calculated from the confusion matrix should match the PR curve values
|
||||||
|
@ -843,7 +844,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
ROC r = new ROC(0, removeRedundantPts);
|
ROC r = new ROC(0, removeRedundantPts);
|
||||||
|
|
||||||
INDArray labels = Nd4j.getExecutioner()
|
INDArray labels = Nd4j.getExecutioner()
|
||||||
.exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE,100, 1), 0.5));
|
.exec(new BernoulliDistribution(Nd4j.createUninitialized(DataType.DOUBLE,100, 1), 0.5));
|
||||||
INDArray probs = Nd4j.rand(100, 1);
|
INDArray probs = Nd4j.rand(100, 1);
|
||||||
|
|
||||||
r.eval(labels, probs);
|
r.eval(labels, probs);
|
||||||
|
@ -874,7 +875,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRocMerge(){
|
public void testRocMerge(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -919,7 +920,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(auprc, auprcAct, 1e-6);
|
assertEquals(auprc, auprcAct, 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRocMultiMerge(){
|
public void testRocMultiMerge(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -931,9 +932,9 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
int nOut = 5;
|
int nOut = 5;
|
||||||
|
|
||||||
Random r = new Random(12345);
|
Random r = new Random(12345);
|
||||||
for( int i=0; i<10; i++ ){
|
for( int i = 0; i < 10; i++ ){
|
||||||
INDArray labels = Nd4j.zeros(3, nOut);
|
INDArray labels = Nd4j.zeros(3, nOut);
|
||||||
for( int j=0; j<3; j++ ){
|
for( int j = 0; j < 3; j++) {
|
||||||
labels.putScalar(j, r.nextInt(nOut), 1.0 );
|
labels.putScalar(j, r.nextInt(nOut), 1.0 );
|
||||||
}
|
}
|
||||||
INDArray out = Nd4j.rand(3, nOut);
|
INDArray out = Nd4j.rand(3, nOut);
|
||||||
|
@ -956,7 +957,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
roc1.merge(roc2);
|
roc1.merge(roc2);
|
||||||
|
|
||||||
for( int i=0; i<nOut; i++ ) {
|
for( int i = 0; i < nOut; i++) {
|
||||||
|
|
||||||
double aucExp = roc.calculateAUC(i);
|
double aucExp = roc.calculateAUC(i);
|
||||||
double auprc = roc.calculateAUCPR(i);
|
double auprc = roc.calculateAUCPR(i);
|
||||||
|
@ -969,9 +970,10 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRocBinaryMerge(){
|
@Disabled
|
||||||
|
public void testRocBinaryMerge(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
ROCBinary roc = new ROCBinary();
|
ROCBinary roc = new ROCBinary();
|
||||||
|
@ -980,7 +982,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
int nOut = 5;
|
int nOut = 5;
|
||||||
|
|
||||||
for( int i=0; i<10; i++ ){
|
for( int i = 0; i < 10; i++) {
|
||||||
INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(3, nOut),0.5));
|
INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(3, nOut),0.5));
|
||||||
INDArray out = Nd4j.rand(3, nOut);
|
INDArray out = Nd4j.rand(3, nOut);
|
||||||
out.diviColumnVector(out.sum(1));
|
out.diviColumnVector(out.sum(1));
|
||||||
|
@ -1015,7 +1017,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSegmentationBinary(){
|
public void testSegmentationBinary(){
|
||||||
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
||||||
|
@ -1106,7 +1108,7 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSegmentation(){
|
public void testSegmentation(){
|
||||||
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
|
||||||
|
|
|
@ -50,7 +50,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvalParameters(Nd4jBackend backend) {
|
public void testEvalParameters(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
int specCols = 5;
|
int specCols = 5;
|
||||||
|
|
|
@ -152,7 +152,7 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
public void maskWhenMerge(Nd4jBackend backend) {
|
public void maskWhenMerge(Nd4jBackend backend) {
|
||||||
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
|
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
|
||||||
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
|
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
|
||||||
List<DataSet> dataSetList = new ArrayList<DataSet>();
|
List<DataSet> dataSetList = new ArrayList<>();
|
||||||
dataSetList.add(dsA);
|
dataSetList.add(dsA);
|
||||||
dataSetList.add(dsB);
|
dataSetList.add(dsB);
|
||||||
DataSet fullDataSet = DataSet.merge(dataSetList);
|
DataSet fullDataSet = DataSet.merge(dataSetList);
|
||||||
|
@ -175,7 +175,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
// System.out.println(b);
|
// System.out.println(b);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
//broken at a threshold
|
//broken at a threshold
|
||||||
public void testArgMax(Nd4jBackend backend) {
|
public void testArgMax(Nd4jBackend backend) {
|
||||||
int max = 63;
|
int max = 63;
|
||||||
|
@ -263,7 +264,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
// log.info("p50: {}; avg: {};", times.get(times.size() / 2), time);
|
// log.info("p50: {}; avg: {};", times.get(times.size() / 2), time);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void checkIllegalElementOps(Nd4jBackend backend) {
|
public void checkIllegalElementOps(Nd4jBackend backend) {
|
||||||
assertThrows(Exception.class,() -> {
|
assertThrows(Exception.class,() -> {
|
||||||
INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5);
|
INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5);
|
||||||
|
@ -328,13 +330,13 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
reshaped.getDouble(i);
|
reshaped.getDouble(i);
|
||||||
}
|
}
|
||||||
for (int j=0;j<arr.slices();j++) {
|
for (int j=0;j<arr.slices();j++) {
|
||||||
for (int k=0;k<arr.slice(j).length();k++) {
|
for (int k = 0; k < arr.slice(j).length(); k++) {
|
||||||
// log.info("\nArr: slice " + j + " element " + k + " " + arr.slice(j).getDouble(k));
|
// log.info("\nArr: slice " + j + " element " + k + " " + arr.slice(j).getDouble(k));
|
||||||
arr.slice(j).getDouble(k);
|
arr.slice(j).getDouble(k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int j=0;j<reshaped.slices();j++) {
|
for (int j = 0;j < reshaped.slices(); j++) {
|
||||||
for (int k=0;k<reshaped.slice(j).length();k++) {
|
for (int k = 0;k < reshaped.slice(j).length(); k++) {
|
||||||
// log.info("\nReshaped: slice " + j + " element " + k + " " + reshaped.slice(j).getDouble(k));
|
// log.info("\nReshaped: slice " + j + " element " + k + " " + reshaped.slice(j).getDouble(k));
|
||||||
reshaped.slice(j).getDouble(k);
|
reshaped.slice(j).getDouble(k);
|
||||||
}
|
}
|
||||||
|
|
|
@ -245,7 +245,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, false);
|
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, false);
|
||||||
assertEquals(sorted[1], sorted2);
|
assertEquals(sorted[1], sorted2);
|
||||||
INDArray shouldIndex = Nd4j.create(new double[] {1, 1, 0, 0}, new long[] {2, 2});
|
INDArray shouldIndex = Nd4j.create(new double[] {1, 1, 0, 0}, new long[] {2, 2});
|
||||||
assertEquals(shouldIndex, sorted[0],getFailureMessage());
|
assertEquals(shouldIndex, sorted[0],getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -266,7 +266,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, true);
|
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, true);
|
||||||
assertEquals(sorted[1], sorted2);
|
assertEquals(sorted[1], sorted2);
|
||||||
INDArray shouldIndex = Nd4j.create(new double[] {0, 0, 1, 1}, new long[] {2, 2});
|
INDArray shouldIndex = Nd4j.create(new double[] {0, 0, 1, 1}, new long[] {2, 2});
|
||||||
assertEquals(shouldIndex, sorted[0],getFailureMessage());
|
assertEquals(shouldIndex, sorted[0],getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -328,13 +328,13 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
public void testDivide(Nd4jBackend backend) {
|
public void testDivide(Nd4jBackend backend) {
|
||||||
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
|
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
|
||||||
INDArray div = two.div(two);
|
INDArray div = two.div(two);
|
||||||
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage());
|
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray half = Nd4j.create(new float[] {0.5f, 0.5f, 0.5f, 0.5f}, new long[] {2, 2});
|
INDArray half = Nd4j.create(new float[] {0.5f, 0.5f, 0.5f, 0.5f}, new long[] {2, 2});
|
||||||
INDArray divi = Nd4j.create(new float[] {0.3f, 0.6f, 0.9f, 0.1f}, new long[] {2, 2});
|
INDArray divi = Nd4j.create(new float[] {0.3f, 0.6f, 0.9f, 0.1f}, new long[] {2, 2});
|
||||||
INDArray assertion = Nd4j.create(new float[] {1.6666666f, 0.8333333f, 0.5555556f, 5}, new long[] {2, 2});
|
INDArray assertion = Nd4j.create(new float[] {1.6666666f, 0.8333333f, 0.5555556f, 5}, new long[] {2, 2});
|
||||||
INDArray result = half.div(divi);
|
INDArray result = half.div(divi);
|
||||||
assertEquals( assertion, result,getFailureMessage());
|
assertEquals( assertion, result,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -344,7 +344,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
|
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
|
||||||
INDArray sigmoid = Transforms.sigmoid(n, false);
|
INDArray sigmoid = Transforms.sigmoid(n, false);
|
||||||
assertEquals( assertion, sigmoid,getFailureMessage());
|
assertEquals( assertion, sigmoid,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -354,7 +354,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
||||||
INDArray neg = Transforms.neg(n);
|
INDArray neg = Transforms.neg(n);
|
||||||
assertEquals(assertion, neg,getFailureMessage());
|
assertEquals(assertion, neg,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -365,12 +365,12 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
double sim = Transforms.cosineSim(vec1, vec2);
|
double sim = Transforms.cosineSim(vec1, vec2);
|
||||||
assertEquals(1, sim, 1e-1,getFailureMessage());
|
assertEquals(1, sim, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
|
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
|
||||||
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
|
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
|
||||||
sim = Transforms.cosineSim(vec3, vec4);
|
sim = Transforms.cosineSim(vec3, vec4);
|
||||||
assertEquals(0.98, sim, 1e-1,getFailureMessage());
|
assertEquals(0.98, sim, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -621,7 +621,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray innerProduct = n.mmul(transposed);
|
INDArray innerProduct = n.mmul(transposed);
|
||||||
|
|
||||||
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
|
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
|
||||||
assertEquals(scalar, innerProduct,getFailureMessage());
|
assertEquals(scalar, innerProduct,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -678,7 +678,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray five = Nd4j.ones(5);
|
INDArray five = Nd4j.ones(5);
|
||||||
five.addi(five.dup());
|
five.addi(five.dup());
|
||||||
INDArray twos = Nd4j.valueArrayOf(5, 2);
|
INDArray twos = Nd4j.valueArrayOf(5, 2);
|
||||||
assertEquals(twos, five,getFailureMessage());
|
assertEquals(twos, five,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -692,7 +692,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
||||||
|
|
||||||
INDArray test = arr.mmul(arr.transpose());
|
INDArray test = arr.mmul(arr.transpose());
|
||||||
assertEquals(assertion, test,getFailureMessage());
|
assertEquals(assertion, test,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -704,7 +704,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
Nd4j.exec(new PrintVariable(newSlice));
|
Nd4j.exec(new PrintVariable(newSlice));
|
||||||
log.info("Slice: {}", newSlice);
|
log.info("Slice: {}", newSlice);
|
||||||
n.putSlice(0, newSlice);
|
n.putSlice(0, newSlice);
|
||||||
assertEquals( newSlice, n.slice(0),getFailureMessage());
|
assertEquals( newSlice, n.slice(0),getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -713,7 +713,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
|
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
|
||||||
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
|
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
|
||||||
linear.putScalar(new long[] {0, 1}, 1);
|
linear.putScalar(new long[] {0, 1}, 1);
|
||||||
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage());
|
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1059,7 +1059,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray nClone = n1.add(n2);
|
INDArray nClone = n1.add(n2);
|
||||||
assertEquals(Nd4j.scalar(3), nClone);
|
assertEquals(Nd4j.scalar(3), nClone);
|
||||||
INDArray n1PlusN2 = n1.add(n2);
|
INDArray n1PlusN2 = n1.add(n2);
|
||||||
assertFalse(n1PlusN2.equals(n1),getFailureMessage());
|
assertFalse(n1PlusN2.equals(n1),getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray n3 = Nd4j.scalar(3.0);
|
INDArray n3 = Nd4j.scalar(3.0);
|
||||||
INDArray n4 = Nd4j.scalar(4.0);
|
INDArray n4 = Nd4j.scalar(4.0);
|
||||||
|
|
|
@ -156,7 +156,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
DataType initialType = Nd4j.dataType();
|
DataType initialType = Nd4j.dataType();
|
||||||
Level1 l1 = Nd4j.getBlasWrapper().level1();
|
Level1 l1 = Nd4j.getBlasWrapper().level1();
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
|
@ -255,7 +255,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSerialization(@TempDir Path testDir) throws Exception {
|
public void testSerialization(Nd4jBackend backend) throws Exception {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray arr = Nd4j.rand(1, 20);
|
INDArray arr = Nd4j.rand(1, 20);
|
||||||
|
|
||||||
|
@ -339,7 +339,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(assertion,shapeTest);
|
assertArrayEquals(assertion,shapeTest);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled //temporary till libnd4j implements general broadcasting
|
@Disabled //temporary till libnd4j implements general broadcasting
|
||||||
public void testAutoBroadcastAdd(Nd4jBackend backend) {
|
public void testAutoBroadcastAdd(Nd4jBackend backend) {
|
||||||
INDArray left = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,1,2,1);
|
INDArray left = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,1,2,1);
|
||||||
|
@ -370,9 +371,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
n.divi(Nd4j.scalar(1.0d));
|
n.divi(Nd4j.scalar(1.0d));
|
||||||
|
|
||||||
n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
|
n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
|
||||||
assertEquals(27, n.sumNumber().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(27, n.sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
INDArray a = n.slice(2);
|
INDArray a = n.slice(2);
|
||||||
assertEquals( true, Arrays.equals(new long[] {3, 3}, a.shape()),getFailureMessage());
|
assertEquals( true, Arrays.equals(new long[] {3, 3}, a.shape()),getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -478,12 +479,13 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
||||||
|
|
||||||
INDArray test = arr.mmul(arr.transpose());
|
INDArray test = arr.mmul(arr.transpose());
|
||||||
assertEquals(assertion, test,getFailureMessage());
|
assertEquals(assertion, test,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled
|
@Disabled
|
||||||
public void testMmulOp() throws Exception {
|
public void testMmulOp(Nd4jBackend backend) throws Exception {
|
||||||
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
|
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
|
||||||
INDArray z = Nd4j.create(2, 2);
|
INDArray z = Nd4j.create(2, 2);
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
||||||
|
@ -494,7 +496,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
DynamicCustomOp op = new Mmul(arr, arr, z, mMulTranspose);
|
DynamicCustomOp op = new Mmul(arr, arr, z, mMulTranspose);
|
||||||
Nd4j.getExecutioner().execAndReturn(op);
|
Nd4j.getExecutioner().execAndReturn(op);
|
||||||
|
|
||||||
assertEquals(assertion, z,getFailureMessage());
|
assertEquals(assertion, z,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -505,7 +507,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray row1 = oneThroughFour.getRow(1).dup();
|
INDArray row1 = oneThroughFour.getRow(1).dup();
|
||||||
oneThroughFour.subiRowVector(row1);
|
oneThroughFour.subiRowVector(row1);
|
||||||
INDArray result = Nd4j.create(new double[] {-2, -2, 0, 0}, new long[] {2, 2});
|
INDArray result = Nd4j.create(new double[] {-2, -2, 0, 0}, new long[] {2, 2});
|
||||||
assertEquals(result, oneThroughFour,getFailureMessage());
|
assertEquals(result, oneThroughFour,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1093,7 +1095,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertTrue(expAllOnes.all());
|
assertTrue(expAllOnes.all());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled
|
@Disabled
|
||||||
public void testSumAlongDim1sEdgeCases(Nd4jBackend backend) {
|
public void testSumAlongDim1sEdgeCases(Nd4jBackend backend) {
|
||||||
val shapes = new long[][] {
|
val shapes = new long[][] {
|
||||||
|
@ -1227,7 +1230,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray row1 = oneThroughFour.getRow(1);
|
INDArray row1 = oneThroughFour.getRow(1);
|
||||||
row1.addi(1);
|
row1.addi(1);
|
||||||
INDArray result = Nd4j.create(new double[] {1, 2, 4, 5}, new long[] {2, 2});
|
INDArray result = Nd4j.create(new double[] {1, 2, 4, 5}, new long[] {2, 2});
|
||||||
assertEquals(result, oneThroughFour,getFailureMessage());
|
assertEquals(result, oneThroughFour,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1241,8 +1244,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray linear = test.reshape(-1);
|
INDArray linear = test.reshape(-1);
|
||||||
linear.putScalar(2, 6);
|
linear.putScalar(2, 6);
|
||||||
linear.putScalar(3, 7);
|
linear.putScalar(3, 7);
|
||||||
assertEquals(6, linear.getFloat(2), 1e-1,getFailureMessage());
|
assertEquals(6, linear.getFloat(2), 1e-1,getFailureMessage(backend));
|
||||||
assertEquals(7, linear.getFloat(3), 1e-1,getFailureMessage());
|
assertEquals(7, linear.getFloat(3), 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1609,7 +1612,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
||||||
INDArray neg = Transforms.neg(n);
|
INDArray neg = Transforms.neg(n);
|
||||||
assertEquals(assertion, neg,getFailureMessage());
|
assertEquals(assertion, neg,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1622,13 +1625,13 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
double assertion = 5.47722557505;
|
double assertion = 5.47722557505;
|
||||||
double norm3 = n.norm2Number().doubleValue();
|
double norm3 = n.norm2Number().doubleValue();
|
||||||
assertEquals(assertion, norm3, 1e-1,getFailureMessage());
|
assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
||||||
INDArray row1 = row.getRow(1);
|
INDArray row1 = row.getRow(1);
|
||||||
double norm2 = row1.norm2Number().doubleValue();
|
double norm2 = row1.norm2Number().doubleValue();
|
||||||
double assertion2 = 5.0f;
|
double assertion2 = 5.0f;
|
||||||
assertEquals(assertion2, norm2, 1e-1,getFailureMessage());
|
assertEquals(assertion2, norm2, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
Nd4j.setDataType(initialType);
|
Nd4j.setDataType(initialType);
|
||||||
}
|
}
|
||||||
|
@ -1640,14 +1643,14 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
float assertion = 5.47722557505f;
|
float assertion = 5.47722557505f;
|
||||||
float norm3 = n.norm2Number().floatValue();
|
float norm3 = n.norm2Number().floatValue();
|
||||||
assertEquals(assertion, norm3, 1e-1,getFailureMessage());
|
assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
INDArray row = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
|
INDArray row = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
|
||||||
INDArray row1 = row.getRow(1);
|
INDArray row1 = row.getRow(1);
|
||||||
float norm2 = row1.norm2Number().floatValue();
|
float norm2 = row1.norm2Number().floatValue();
|
||||||
float assertion2 = 5.0f;
|
float assertion2 = 5.0f;
|
||||||
assertEquals(assertion2, norm2, 1e-1,getFailureMessage());
|
assertEquals(assertion2, norm2, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1659,7 +1662,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
double sim = Transforms.cosineSim(vec1, vec2);
|
double sim = Transforms.cosineSim(vec1, vec2);
|
||||||
assertEquals(1, sim, 1e-1,getFailureMessage());
|
assertEquals(1, sim, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
|
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
|
||||||
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
|
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
|
||||||
|
@ -1675,14 +1678,14 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
double assertion = 2;
|
double assertion = 2;
|
||||||
INDArray answer = Nd4j.create(new double[] {2, 4, 6, 8});
|
INDArray answer = Nd4j.create(new double[] {2, 4, 6, 8});
|
||||||
INDArray scal = Nd4j.getBlasWrapper().scal(assertion, answer);
|
INDArray scal = Nd4j.getBlasWrapper().scal(assertion, answer);
|
||||||
assertEquals(answer, scal,getFailureMessage());
|
assertEquals(answer, scal,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
||||||
INDArray row1 = row.getRow(1);
|
INDArray row1 = row.getRow(1);
|
||||||
double assertion2 = 5.0;
|
double assertion2 = 5.0;
|
||||||
INDArray answer2 = Nd4j.create(new double[] {15, 20});
|
INDArray answer2 = Nd4j.create(new double[] {15, 20});
|
||||||
INDArray scal2 = Nd4j.getBlasWrapper().scal(assertion2, row1);
|
INDArray scal2 = Nd4j.getBlasWrapper().scal(assertion2, row1);
|
||||||
assertEquals(answer2, scal2,getFailureMessage());
|
assertEquals(answer2, scal2,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2076,17 +2079,17 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray innerProduct = n.mmul(transposed);
|
INDArray innerProduct = n.mmul(transposed);
|
||||||
|
|
||||||
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
|
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
|
||||||
assertEquals(scalar, innerProduct,getFailureMessage());
|
assertEquals(scalar, innerProduct,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray outerProduct = transposed.mmul(n);
|
INDArray outerProduct = transposed.mmul(n);
|
||||||
assertEquals(true, Shape.shapeEquals(new long[] {10, 10}, outerProduct.shape()),getFailureMessage());
|
assertEquals(true, Shape.shapeEquals(new long[] {10, 10}, outerProduct.shape()),getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
INDArray three = Nd4j.create(new double[] {3, 4});
|
INDArray three = Nd4j.create(new double[] {3, 4});
|
||||||
INDArray test = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2});
|
INDArray test = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2});
|
||||||
INDArray sliceRow = test.slice(0).getRow(1);
|
INDArray sliceRow = test.slice(0).getRow(1);
|
||||||
assertEquals(three, sliceRow,getFailureMessage());
|
assertEquals(three, sliceRow,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray twoSix = Nd4j.create(new double[] {2, 6}, new long[] {2, 1});
|
INDArray twoSix = Nd4j.create(new double[] {2, 6}, new long[] {2, 1});
|
||||||
INDArray threeTwoSix = three.mmul(twoSix);
|
INDArray threeTwoSix = three.mmul(twoSix);
|
||||||
|
@ -2114,7 +2117,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray k1 = n1.transpose();
|
INDArray k1 = n1.transpose();
|
||||||
|
|
||||||
INDArray testVectorVector = k1.mmul(n1);
|
INDArray testVectorVector = k1.mmul(n1);
|
||||||
assertEquals(vectorVector, testVectorVector,getFailureMessage());
|
assertEquals(vectorVector, testVectorVector,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -2204,7 +2207,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(linear.getDouble(0, 1), 1, 1e-1);
|
assertEquals(linear.getDouble(0, 1), 1, 1e-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSize(Nd4jBackend backend) {
|
public void testSize(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
INDArray arr = Nd4j.create(4, 5);
|
INDArray arr = Nd4j.create(4, 5);
|
||||||
|
@ -2357,7 +2361,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled
|
@Disabled
|
||||||
public void testTensorDot(Nd4jBackend backend) {
|
public void testTensorDot(Nd4jBackend backend) {
|
||||||
INDArray oneThroughSixty = Nd4j.arange(60).reshape(3, 4, 5).castTo(DataType.DOUBLE);
|
INDArray oneThroughSixty = Nd4j.arange(60).reshape(3, 4, 5).castTo(DataType.DOUBLE);
|
||||||
|
@ -3051,10 +3056,10 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
public void testMeans(Nd4jBackend backend) {
|
public void testMeans(Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
INDArray mean1 = a.mean(1);
|
INDArray mean1 = a.mean(1);
|
||||||
assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage());
|
assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage(backend));
|
||||||
assertEquals(Nd4j.create(new double[] {2, 3}), a.mean(0),getFailureMessage());
|
assertEquals(Nd4j.create(new double[] {2, 3}), a.mean(0),getFailureMessage(backend));
|
||||||
assertEquals(2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
assertEquals(2.5, a.meanNumber().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(2.5, a.meanNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3063,9 +3068,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSums(Nd4jBackend backend) {
|
public void testSums(Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage());
|
assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage(backend));
|
||||||
assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage());
|
assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage(backend));
|
||||||
assertEquals(10, a.sumNumber().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(10, a.sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -3438,7 +3443,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled
|
@Disabled
|
||||||
public void largeInstantiation(Nd4jBackend backend) {
|
public void largeInstantiation(Nd4jBackend backend) {
|
||||||
Nd4j.ones((1024 * 1024 * 511) + (1024 * 1024 - 1)); // Still works; this can even be called as often as I want, allowing me even to spill over on disk
|
Nd4j.ones((1024 * 1024 * 511) + (1024 * 1024 - 1)); // Still works; this can even be called as often as I want, allowing me even to spill over on disk
|
||||||
|
@ -3487,7 +3493,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(cSum, fSum); //Expect: 4,6. Getting [4, 4] for f order
|
assertEquals(cSum, fSum); //Expect: 4,6. Getting [4, 4] for f order
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled //not relevant anymore
|
@Disabled //not relevant anymore
|
||||||
public void testAssignMixedC(Nd4jBackend backend) {
|
public void testAssignMixedC(Nd4jBackend backend) {
|
||||||
int[] shape1 = {3, 2, 2, 2, 2, 2};
|
int[] shape1 = {3, 2, 2, 2, 2, 2};
|
||||||
|
@ -3787,7 +3794,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(assertion, result);
|
assertEquals(assertion, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPullRowsValidation1(Nd4jBackend backend) {
|
public void testPullRowsValidation1(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
Nd4j.pullRows(Nd4j.create(10, 10), 2, new int[] {0, 1, 2});
|
Nd4j.pullRows(Nd4j.create(10, 10), 2, new int[] {0, 1, 2});
|
||||||
|
@ -3795,7 +3803,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPullRowsValidation2(Nd4jBackend backend) {
|
public void testPullRowsValidation2(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, -1, 2});
|
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, -1, 2});
|
||||||
|
@ -3803,7 +3812,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPullRowsValidation3(Nd4jBackend backend) {
|
public void testPullRowsValidation3(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, 1, 10});
|
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, 1, 10});
|
||||||
|
@ -3811,7 +3821,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPullRowsValidation4(Nd4jBackend backend) {
|
public void testPullRowsValidation4(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2, 3});
|
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2, 3});
|
||||||
|
@ -3819,7 +3830,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPullRowsValidation5(Nd4jBackend backend) {
|
public void testPullRowsValidation5(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2}, 'e');
|
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2}, 'e');
|
||||||
|
@ -4975,7 +4987,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTadReduce3_5(Nd4jBackend backend) {
|
public void testTadReduce3_5(Nd4jBackend backend) {
|
||||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||||
INDArray initial = Nd4j.create(5, 10);
|
INDArray initial = Nd4j.create(5, 10);
|
||||||
|
@ -6004,7 +6017,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled
|
@Disabled
|
||||||
public void testLogExpSum1(Nd4jBackend backend) {
|
public void testLogExpSum1(Nd4jBackend backend) {
|
||||||
INDArray matrix = Nd4j.create(3, 3);
|
INDArray matrix = Nd4j.create(3, 3);
|
||||||
|
@ -6019,7 +6033,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled
|
@Disabled
|
||||||
public void testLogExpSum2(Nd4jBackend backend) {
|
public void testLogExpSum2(Nd4jBackend backend) {
|
||||||
INDArray row = Nd4j.create(new double[]{1, 2, 3});
|
INDArray row = Nd4j.create(new double[]{1, 2, 3});
|
||||||
|
@ -6246,7 +6261,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReshapeFailure(Nd4jBackend backend) {
|
public void testReshapeFailure(Nd4jBackend backend) {
|
||||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||||
val a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2);
|
val a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2);
|
||||||
|
@ -6345,7 +6361,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(new long[]{3, 2}, newShape.shape());
|
assertArrayEquals(new long[]{3, 2}, newShape.shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTranspose1(Nd4jBackend backend) {
|
public void testTranspose1(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
|
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
|
||||||
|
@ -6360,7 +6377,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTranspose2(Nd4jBackend backend) {
|
public void testTranspose2(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
val scalar = Nd4j.scalar(2.f);
|
val scalar = Nd4j.scalar(2.f);
|
||||||
|
@ -6375,7 +6393,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
//@Disabled
|
//@Disabled
|
||||||
public void testMatmul_128by256(Nd4jBackend backend) {
|
public void testMatmul_128by256(Nd4jBackend backend) {
|
||||||
val mA = Nd4j.create(128, 156).assign(1.0f);
|
val mA = Nd4j.create(128, 156).assign(1.0f);
|
||||||
|
@ -6647,7 +6666,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp1, out1);
|
assertEquals(exp1, out1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBadReduce3Call(Nd4jBackend backend) {
|
public void testBadReduce3Call(Nd4jBackend backend) {
|
||||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||||
val x = Nd4j.create(400,20);
|
val x = Nd4j.create(400,20);
|
||||||
|
@ -7392,8 +7412,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(ez, z);
|
assertEquals(ez, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
public void testBroadcastInvalid(){
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
public void testBroadcastInvalid() {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
INDArray arr1 = Nd4j.ones(3,4,1);
|
INDArray arr1 = Nd4j.ones(3,4,1);
|
||||||
|
|
||||||
|
@ -7656,7 +7677,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp, array);
|
assertEquals(exp, array);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScatterUpdateShortcut_f1(Nd4jBackend backend) {
|
public void testScatterUpdateShortcut_f1(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
val array = Nd4j.create(DataType.FLOAT, 5, 2);
|
val array = Nd4j.create(DataType.FLOAT, 5, 2);
|
||||||
|
@ -8041,7 +8063,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp, out); //Failing here
|
assertEquals(exp, out); //Failing here
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPullRowsFailure(Nd4jBackend backend) {
|
public void testPullRowsFailure(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
val idxs = new int[]{0,2,3,4};
|
val idxs = new int[]{0,2,3,4};
|
||||||
|
@ -8144,7 +8167,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp1, out1); //This is OK
|
assertEquals(exp1, out1); //This is OK
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPutRowValidation(Nd4jBackend backend) {
|
public void testPutRowValidation(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
val matrix = Nd4j.create(5, 10);
|
val matrix = Nd4j.create(5, 10);
|
||||||
|
@ -8155,7 +8179,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPutColumnValidation(Nd4jBackend backend) {
|
public void testPutColumnValidation(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
val matrix = Nd4j.create(5, 10);
|
val matrix = Nd4j.create(5, 10);
|
||||||
|
@ -8236,7 +8261,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScalarEq(){
|
public void testScalarEq(Nd4jBackend backend){
|
||||||
INDArray scalarRank2 = Nd4j.scalar(10.0).reshape(1,1);
|
INDArray scalarRank2 = Nd4j.scalar(10.0).reshape(1,1);
|
||||||
INDArray scalarRank1 = Nd4j.scalar(10.0).reshape(1);
|
INDArray scalarRank1 = Nd4j.scalar(10.0).reshape(1);
|
||||||
INDArray scalarRank0 = Nd4j.scalar(10.0);
|
INDArray scalarRank0 = Nd4j.scalar(10.0);
|
||||||
|
@ -8273,7 +8298,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testType1(@TempDir Path testDir) throws IOException {
|
@Disabled
|
||||||
|
public void testType1(Nd4jBackend backend) throws IOException {
|
||||||
for (int i = 0; i < 10; ++i) {
|
for (int i = 0; i < 10; ++i) {
|
||||||
INDArray in1 = Nd4j.rand(DataType.DOUBLE, new int[]{100, 100});
|
INDArray in1 = Nd4j.rand(DataType.DOUBLE, new int[]{100, 100});
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
|
@ -8295,7 +8321,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOnes(){
|
public void testOnes(Nd4jBackend backend){
|
||||||
INDArray arr = Nd4j.ones();
|
INDArray arr = Nd4j.ones();
|
||||||
INDArray arr2 = Nd4j.ones(DataType.LONG);
|
INDArray arr2 = Nd4j.ones(DataType.LONG);
|
||||||
assertEquals(0, arr.rank());
|
assertEquals(0, arr.rank());
|
||||||
|
@ -8306,7 +8332,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testZeros(){
|
public void testZeros(Nd4jBackend backend){
|
||||||
INDArray arr = Nd4j.zeros();
|
INDArray arr = Nd4j.zeros();
|
||||||
INDArray arr2 = Nd4j.zeros(DataType.LONG);
|
INDArray arr2 = Nd4j.zeros(DataType.LONG);
|
||||||
assertEquals(0, arr.rank());
|
assertEquals(0, arr.rank());
|
||||||
|
@ -8317,7 +8343,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testType2(@TempDir Path testDir) throws IOException {
|
@Disabled
|
||||||
|
public void testType2(Nd4jBackend backend) throws IOException {
|
||||||
for (int i = 0; i < 10; ++i) {
|
for (int i = 0; i < 10; ++i) {
|
||||||
INDArray in1 = Nd4j.ones(DataType.UINT16);
|
INDArray in1 = Nd4j.ones(DataType.UINT16);
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
|
|
|
@ -23,6 +23,7 @@ package org.nd4j.linalg;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -58,11 +59,12 @@ public class ToStringTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testToStringScalars(){
|
@Disabled
|
||||||
|
public void testToStringScalars(Nd4jBackend backend){
|
||||||
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
|
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
|
||||||
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};
|
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};
|
||||||
|
|
||||||
for(int dt=0; dt<5; dt++ ) {
|
for(int dt = 0; dt < 5; dt++) {
|
||||||
for (int i = 0; i < 5; i++) {
|
for (int i = 0; i < 5; i++) {
|
||||||
long[] shape = ArrayUtil.nTimes(i, 1L);
|
long[] shape = ArrayUtil.nTimes(i, 1L);
|
||||||
INDArray scalar = Nd4j.scalar(1.0f).castTo(dataTypes[dt]).reshape(shape);
|
INDArray scalar = Nd4j.scalar(1.0f).castTo(dataTypes[dt]).reshape(shape);
|
||||||
|
|
|
@ -64,7 +64,6 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@Disabled
|
@Disabled
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@ -79,7 +78,6 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@Disabled
|
@Disabled
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@ -100,7 +98,8 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCreateNpy3(Nd4jBackend backend) throws Exception {
|
public void testCreateNpy3(Nd4jBackend backend) throws Exception {
|
||||||
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
|
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
|
||||||
assertEquals(8, arrCreate.length());
|
assertEquals(8, arrCreate.length());
|
||||||
|
@ -111,8 +110,9 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(arrCreate.data().address(), pointer.address());
|
assertEquals(arrCreate.data().address(), pointer.address());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@Disabled // this is endless test
|
@Disabled // this is endless test
|
||||||
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEndlessAllocation(Nd4jBackend backend) {
|
public void testEndlessAllocation(Nd4jBackend backend) {
|
||||||
Nd4j.getEnvironment().setMaxSpecialMemory(1);
|
Nd4j.getEnvironment().setMaxSpecialMemory(1);
|
||||||
while (true) {
|
while (true) {
|
||||||
|
@ -121,9 +121,10 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@Disabled("This test is designed to run in isolation. With parallel gc it makes no real sense since allocated amount changes at any time")
|
@Disabled("This test is designed to run in isolation. With parallel gc it makes no real sense since allocated amount changes at any time")
|
||||||
public void testAllocationLimits() throws Exception {
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
public void testAllocationLimits(Nd4jBackend backend) throws Exception {
|
||||||
Nd4j.create(1);
|
Nd4j.create(1);
|
||||||
|
|
||||||
val origDeviceLimit = Nd4j.getEnvironment().getDeviceLimit(0);
|
val origDeviceLimit = Nd4j.getEnvironment().getDeviceLimit(0);
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api;
|
package org.nd4j.linalg.api;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||||
|
|
|
@ -59,7 +59,7 @@ public class Level1Test extends BaseNd4jTestWithBackends {
|
||||||
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
INDArray row = matrix.getRow(1);
|
INDArray row = matrix.getRow(1);
|
||||||
Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row);
|
Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row);
|
||||||
assertEquals(Nd4j.create(new double[] {4, 8}), row,getFailureMessage());
|
assertEquals(Nd4j.create(new double[] {4, 8}), row,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -70,8 +70,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
|
||||||
/**
|
/**
|
||||||
* Testing level1 blas
|
* Testing level1 blas
|
||||||
*/
|
*/
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBlasValidation1(Nd4jBackend backend) {
|
public void testBlasValidation1(Nd4jBackend backend) {
|
||||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||||
|
@ -89,8 +88,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
|
||||||
/**
|
/**
|
||||||
* Testing level2 blas
|
* Testing level2 blas
|
||||||
*/
|
*/
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBlasValidation2(Nd4jBackend backend) {
|
public void testBlasValidation2(Nd4jBackend backend) {
|
||||||
assertThrows(RuntimeException.class,() -> {
|
assertThrows(RuntimeException.class,() -> {
|
||||||
|
@ -109,8 +107,7 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
|
||||||
/**
|
/**
|
||||||
* Testing level3 blas
|
* Testing level3 blas
|
||||||
*/
|
*/
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBlasValidation3(Nd4jBackend backend) {
|
public void testBlasValidation3(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
|
|
|
@ -88,7 +88,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
float[] d1 = new float[] {1, 2, 3, 4};
|
float[] d1 = new float[] {1, 2, 3, 4};
|
||||||
DataBuffer d = Nd4j.createBuffer(d1);
|
DataBuffer d = Nd4j.createBuffer(d1);
|
||||||
float[] d2 = d.asFloat();
|
float[] d2 = d.asFloat();
|
||||||
assertArrayEquals( d1, d2, 1e-1f,getFailureMessage());
|
assertArrayEquals( d1, d2, 1e-1f,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,7 +146,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
d.put(0, 0.0);
|
d.put(0, 0.0);
|
||||||
float[] result = new float[] {0, 2, 3, 4};
|
float[] result = new float[] {0, 2, 3, 4};
|
||||||
d1 = d.asFloat();
|
d1 = d.asFloat();
|
||||||
assertArrayEquals(d1, result, 1e-1f,getFailureMessage());
|
assertArrayEquals(d1, result, 1e-1f,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -156,12 +156,12 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
||||||
float[] get = buffer.getFloatsAt(0, 3);
|
float[] get = buffer.getFloatsAt(0, 3);
|
||||||
float[] data = new float[] {1, 2, 3};
|
float[] data = new float[] {1, 2, 3};
|
||||||
assertArrayEquals(get, data, 1e-1f,getFailureMessage());
|
assertArrayEquals(get, data, 1e-1f,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
float[] get2 = buffer.asFloat();
|
float[] get2 = buffer.asFloat();
|
||||||
float[] allData = buffer.getFloatsAt(0, (int) buffer.length());
|
float[] allData = buffer.getFloatsAt(0, (int) buffer.length());
|
||||||
assertArrayEquals(get2, allData, 1e-1f,getFailureMessage());
|
assertArrayEquals(get2, allData, 1e-1f,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -173,13 +173,13 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
||||||
float[] get = buffer.getFloatsAt(1, 3);
|
float[] get = buffer.getFloatsAt(1, 3);
|
||||||
float[] data = new float[] {2, 3, 4};
|
float[] data = new float[] {2, 3, 4};
|
||||||
assertArrayEquals(get, data, 1e-1f,getFailureMessage());
|
assertArrayEquals(get, data, 1e-1f,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
float[] allButLast = new float[] {2, 3, 4, 5};
|
float[] allButLast = new float[] {2, 3, 4, 5};
|
||||||
|
|
||||||
float[] allData = buffer.getFloatsAt(1, (int) buffer.length());
|
float[] allData = buffer.getFloatsAt(1, (int) buffer.length());
|
||||||
assertArrayEquals(allButLast, allData, 1e-1f,getFailureMessage());
|
assertArrayEquals(allButLast, allData, 1e-1f,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -190,7 +190,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
public void testAsBytes(Nd4jBackend backend) {
|
public void testAsBytes(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.create(5);
|
INDArray arr = Nd4j.create(5);
|
||||||
byte[] d = arr.data().asBytes();
|
byte[] d = arr.data().asBytes();
|
||||||
assertEquals(4 * 5, d.length,getFailureMessage());
|
assertEquals(4 * 5, d.length,getFailureMessage(backend));
|
||||||
INDArray rand = Nd4j.rand(3, 3);
|
INDArray rand = Nd4j.rand(3, 3);
|
||||||
rand.data().asBytes();
|
rand.data().asBytes();
|
||||||
|
|
||||||
|
|
|
@ -20,26 +20,18 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.indexing;
|
package org.nd4j.linalg.api.indexing;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
import org.nd4j.common.util.ArrayUtil;
|
||||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
import org.nd4j.linalg.indexing.*;
|
||||||
import org.nd4j.linalg.indexing.IntervalIndex;
|
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndexAll;
|
|
||||||
import org.nd4j.linalg.indexing.NewAxis;
|
|
||||||
import org.nd4j.linalg.indexing.PointIndex;
|
|
||||||
import org.nd4j.linalg.indexing.SpecifiedIndex;
|
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import org.nd4j.common.util.ArrayUtil;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
@ -56,22 +48,22 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNegativeBounds() {
|
public void testNegativeBounds(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
|
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
|
||||||
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
|
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
|
||||||
INDArray get = arr.get(NDArrayIndex.all(),interval);
|
INDArray get = arr.get(NDArrayIndex.all(),interval);
|
||||||
INDArray assertion = Nd4j.create(new double[][]{
|
INDArray assertion = Nd4j.create(new double[][]{
|
||||||
{1,2,3},
|
{1,2,3},
|
||||||
{6,7,8}
|
{6,7,8}
|
||||||
});
|
});
|
||||||
assertEquals(assertion,get);
|
assertEquals(assertion,get);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNewAxis() {
|
public void testNewAxis(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2);
|
INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2);
|
||||||
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all());
|
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all());
|
||||||
long[] shapeAssertion = {3, 2, 1, 1, 2};
|
long[] shapeAssertion = {3, 2, 1, 1, 2};
|
||||||
|
@ -79,9 +71,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void broadcastBug() {
|
public void broadcastBug(Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2});
|
INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2});
|
||||||
final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0));
|
final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0));
|
||||||
|
|
||||||
|
@ -91,9 +83,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIntervalsIn3D() {
|
public void testIntervalsIn3D(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
|
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
|
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
|
||||||
INDArray rest = arr.get(interval(1, 2), interval(0, 2), interval(0, 2));
|
INDArray rest = arr.get(interval(1, 2), interval(0, 2), interval(0, 2));
|
||||||
|
@ -101,9 +93,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSmallInterval() {
|
public void testSmallInterval(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
|
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
|
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
|
||||||
INDArray rest = arr.get(interval(1, 2), all(), all());
|
INDArray rest = arr.get(interval(1, 2), all(), all());
|
||||||
|
@ -111,9 +103,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAllWithNewAxisAndInterval() {
|
public void testAllWithNewAxisAndInterval(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||||
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3);
|
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3);
|
||||||
|
|
||||||
|
@ -121,9 +113,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(assertion2, get2);
|
assertEquals(assertion2, get2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAllWithNewAxisInMiddle() {
|
public void testAllWithNewAxisInMiddle(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||||
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3);
|
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3);
|
||||||
|
|
||||||
|
@ -131,20 +123,20 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(assertion2, get2);
|
assertEquals(assertion2, get2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAllWithNewAxis() {
|
public void testAllWithNewAxis(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||||
INDArray get = arr.get(newAxis(), all(), point(1));
|
INDArray get = arr.get(newAxis(), all(), point(1));
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{4, 5, 6}, {10, 11, 12}, {16, 17, 18}, {22, 23, 24}})
|
INDArray assertion = Nd4j.create(new double[][] {{4, 5, 6}, {10, 11, 12}, {16, 17, 18}, {22, 23, 24}})
|
||||||
.reshape(1, 4, 3);
|
.reshape(1, 4, 3);
|
||||||
assertEquals(assertion, get);
|
assertEquals(assertion, get);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIndexingWithMmul() {
|
public void testIndexingWithMmul(Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
|
INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
|
||||||
INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
|
INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
|
||||||
// System.out.println(b);
|
// System.out.println(b);
|
||||||
|
@ -154,9 +146,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(assertion, c);
|
assertEquals(assertion, c);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPointPointInterval() {
|
public void testPointPointInterval(Nd4jBackend backend) {
|
||||||
INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3);
|
INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3);
|
||||||
INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3));
|
INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3));
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{5, 6}, {8, 9}});
|
INDArray assertion = Nd4j.create(new double[][] {{5, 6}, {8, 9}});
|
||||||
|
@ -164,9 +156,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(assertion, get);
|
assertEquals(assertion, get);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIntervalLowerBound() {
|
public void testIntervalLowerBound(Nd4jBackend backend) {
|
||||||
INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||||
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
|
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{7, 9}, {13, 15}});
|
INDArray assertion = Nd4j.create(new double[][] {{7, 9}, {13, 15}});
|
||||||
|
@ -176,9 +168,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetPointRowVector() {
|
public void testGetPointRowVector(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1);
|
INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1);
|
||||||
|
|
||||||
INDArray arr2 = arr.get(point(0), interval(0, 100));
|
INDArray arr2 = arr.get(point(0), interval(0, 100));
|
||||||
|
@ -187,9 +179,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(Nd4j.linspace(1, 100, 100, DataType.DOUBLE), arr2);
|
assertEquals(Nd4j.linspace(1, 100, 100, DataType.DOUBLE), arr2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSpecifiedIndexVector() {
|
public void testSpecifiedIndexVector(Nd4jBackend backend) {
|
||||||
INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
|
INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
|
||||||
INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2);
|
INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2);
|
||||||
INDArray get = rootMatrix.get(all(), new SpecifiedIndex(0, 2));
|
INDArray get = rootMatrix.get(all(), new SpecifiedIndex(0, 2));
|
||||||
|
@ -205,9 +197,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPutRowIndexing() {
|
public void testPutRowIndexing(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.ones(1, 10);
|
INDArray arr = Nd4j.ones(1, 10);
|
||||||
INDArray row = Nd4j.create(1, 10);
|
INDArray row = Nd4j.create(1, 10);
|
||||||
|
|
||||||
|
@ -216,9 +208,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(arr, row);
|
assertEquals(arr, row);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testVectorIndexing2() {
|
public void testVectorIndexing2(Nd4jBackend backend) {
|
||||||
INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true));
|
INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true));
|
||||||
INDArray assertion = Nd4j.create(new double[] {2, 4});
|
INDArray assertion = Nd4j.create(new double[] {2, 4});
|
||||||
assertEquals(assertion, wholeVector);
|
assertEquals(assertion, wholeVector);
|
||||||
|
@ -232,9 +224,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOffsetsC() {
|
public void testOffsetsC(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
assertEquals(3, NDArrayIndex.offset(arr, 1, 1));
|
assertEquals(3, NDArrayIndex.offset(arr, 1, 1));
|
||||||
assertEquals(3, NDArrayIndex.offset(arr, point(1), point(1)));
|
assertEquals(3, NDArrayIndex.offset(arr, point(1), point(1)));
|
||||||
|
@ -249,9 +241,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIndexFor() {
|
public void testIndexFor(Nd4jBackend backend) {
|
||||||
long[] shape = {1, 2};
|
long[] shape = {1, 2};
|
||||||
INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape);
|
INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape);
|
||||||
for (int i = 0; i < indexes.length; i++) {
|
for (int i = 0; i < indexes.length; i++) {
|
||||||
|
@ -259,9 +251,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetScalar() {
|
public void testGetScalar(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
|
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
|
||||||
INDArray d = arr.get(point(1));
|
INDArray d = arr.get(point(1));
|
||||||
assertTrue(d.isScalar());
|
assertTrue(d.isScalar());
|
||||||
|
@ -269,26 +261,26 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testVectorIndexing() {
|
public void testVectorIndexing(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1);
|
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1);
|
||||||
INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
|
INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
|
||||||
INDArray viewTest = arr.get(point(0), interval(1, 5));
|
INDArray viewTest = arr.get(point(0), interval(1, 5));
|
||||||
assertEquals(assertion, viewTest);
|
assertEquals(assertion, viewTest);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNegativeIndices() {
|
public void testNegativeIndices(Nd4jBackend backend) {
|
||||||
INDArray test = Nd4j.create(10, 10, 10);
|
INDArray test = Nd4j.create(10, 10, 10);
|
||||||
test.putScalar(new int[] {0, 0, -1}, 1.0);
|
test.putScalar(new int[] {0, 0, -1}, 1.0);
|
||||||
assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber());
|
assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetIndices2d() {
|
public void testGetIndices2d(Nd4jBackend backend) {
|
||||||
INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2);
|
INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2);
|
||||||
INDArray firstRow = twoByTwo.getRow(0);
|
INDArray firstRow = twoByTwo.getRow(0);
|
||||||
INDArray secondRow = twoByTwo.getRow(1);
|
INDArray secondRow = twoByTwo.getRow(1);
|
||||||
|
@ -305,9 +297,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(Nd4j.create(new double[] {4}, new int[]{1,1}), individualElement);
|
assertEquals(Nd4j.create(new double[] {4}, new int[]{1,1}), individualElement);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetRow() {
|
public void testGetRow(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5);
|
INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5);
|
||||||
int[] toGet = {0, 1};
|
int[] toGet = {0, 1};
|
||||||
|
@ -323,9 +315,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetRowEdgeCase() {
|
public void testGetRowEdgeCase(Nd4jBackend backend) {
|
||||||
INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
|
INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
|
||||||
INDArray get = rowVec.getRow(0); //Returning shape [1,1]
|
INDArray get = rowVec.getRow(0); //Returning shape [1,1]
|
||||||
|
|
||||||
|
@ -333,9 +325,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(rowVec, get);
|
assertEquals(rowVec, get);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetColumnEdgeCase() {
|
public void testGetColumnEdgeCase(Nd4jBackend backend) {
|
||||||
INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose();
|
INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose();
|
||||||
INDArray get = colVec.getColumn(0); //Returning shape [1,1]
|
INDArray get = colVec.getColumn(0); //Returning shape [1,1]
|
||||||
|
|
||||||
|
@ -343,9 +335,9 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(colVec, get);
|
assertEquals(colVec, get);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConcatColumns() {
|
public void testConcatColumns(Nd4jBackend backend) {
|
||||||
INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE);
|
INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE);
|
||||||
INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE);
|
INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE);
|
||||||
INDArray concat = Nd4j.concat(1, input1, input2);
|
INDArray concat = Nd4j.concat(1, input1, input2);
|
||||||
|
@ -353,18 +345,18 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(assertion, concat);
|
assertEquals(assertion, concat);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetIndicesVector() {
|
public void testGetIndicesVector(Nd4jBackend backend) {
|
||||||
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
|
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
|
||||||
INDArray test = Nd4j.create(new double[] {2, 3});
|
INDArray test = Nd4j.create(new double[] {2, 3});
|
||||||
INDArray result = line.get(point(0), interval(1, 3));
|
INDArray result = line.get(point(0), interval(1, 3));
|
||||||
assertEquals(test, result);
|
assertEquals(test, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testArangeMul() {
|
public void testArangeMul(Nd4jBackend backend) {
|
||||||
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
|
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
|
||||||
INDArrayIndex index = interval(0, 2);
|
INDArrayIndex index = interval(0, 2);
|
||||||
INDArray get = arange.get(index, index);
|
INDArray get = arange.get(index, index);
|
||||||
|
@ -374,7 +366,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(assertion, mul);
|
assertEquals(assertion, mul);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIndexingThorough(){
|
public void testIndexingThorough(){
|
||||||
long[] fullShape = {3,4,5,6,7};
|
long[] fullShape = {3,4,5,6,7};
|
||||||
|
@ -575,7 +567,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
return d;
|
return d;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void debugging(){
|
public void debugging(){
|
||||||
long[] inShape = {3,4};
|
long[] inShape = {3,4};
|
||||||
|
|
|
@ -46,12 +46,13 @@ import java.util.List;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
|
||||||
public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
|
public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void compareAfterWrite(Nd4jBackend backend) throws Exception {
|
||||||
int [] ranksToCheck = new int[] {0,1,2,3,4};
|
int [] ranksToCheck = new int[] {0,1,2,3,4};
|
||||||
for (int i = 0; i < ranksToCheck.length; i++) {
|
for (int i = 0; i < ranksToCheck.length; i++) {
|
||||||
// log.info("Checking read write arrays with rank " + ranksToCheck[i]);
|
// log.info("Checking read write arrays with rank " + ranksToCheck[i]);
|
||||||
|
@ -82,7 +83,7 @@ public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNd4jReadWriteText(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testNd4jReadWriteText(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
int count = 0;
|
int count = 0;
|
||||||
|
|
|
@ -38,11 +38,11 @@ import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
|
||||||
public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends {
|
public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends {
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void compareAfterWrite(Nd4jBackend backend) throws Exception {
|
||||||
int[] ranksToCheck = new int[]{0, 1, 2, 3, 4};
|
int[] ranksToCheck = new int[]{0, 1, 2, 3, 4};
|
||||||
for (int i = 0; i < ranksToCheck.length; i++) {
|
for (int i = 0; i < ranksToCheck.length; i++) {
|
||||||
log.info("Checking read write arrays with rank " + ranksToCheck[i]);
|
log.info("Checking read write arrays with rank " + ranksToCheck[i]);
|
||||||
|
|
|
@ -22,6 +22,7 @@ package org.nd4j.linalg.broadcast;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -135,7 +136,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(e, z);
|
assertEquals(e, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void basicBroadcastFailureTest_1(Nd4jBackend backend) {
|
public void basicBroadcastFailureTest_1(Nd4jBackend backend) {
|
||||||
|
@ -146,7 +146,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void basicBroadcastFailureTest_2(Nd4jBackend backend) {
|
public void basicBroadcastFailureTest_2(Nd4jBackend backend) {
|
||||||
|
@ -158,7 +157,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void basicBroadcastFailureTest_3(Nd4jBackend backend) {
|
public void basicBroadcastFailureTest_3(Nd4jBackend backend) {
|
||||||
|
@ -170,16 +168,15 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
public void basicBroadcastFailureTest_4(Nd4jBackend backend) {
|
public void basicBroadcastFailureTest_4(Nd4jBackend backend) {
|
||||||
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
|
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
|
||||||
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
|
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
|
||||||
val z = x.addi(y);
|
val z = x.addi(y);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void basicBroadcastFailureTest_5(Nd4jBackend backend) {
|
public void basicBroadcastFailureTest_5(Nd4jBackend backend) {
|
||||||
|
@ -191,7 +188,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void basicBroadcastFailureTest_6(Nd4jBackend backend) {
|
public void basicBroadcastFailureTest_6(Nd4jBackend backend) {
|
||||||
|
@ -249,9 +245,9 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(y, z);
|
assertEquals(y, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
public void emptyBroadcastTest_2(Nd4jBackend backend) {
|
public void emptyBroadcastTest_2(Nd4jBackend backend) {
|
||||||
val x = Nd4j.create(DataType.FLOAT, 1, 2);
|
val x = Nd4j.create(DataType.FLOAT, 1, 2);
|
||||||
val y = Nd4j.create(DataType.FLOAT, 0, 2);
|
val y = Nd4j.create(DataType.FLOAT, 0, 2);
|
||||||
|
|
|
@ -37,7 +37,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
public class CompressionMagicTests extends BaseNd4jTestWithBackends {
|
public class CompressionMagicTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void setUp(Nd4jBackend backend) {
|
public void setUp() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -48,6 +48,7 @@ import java.util.Set;
|
||||||
|
|
||||||
public class DeconvTests extends BaseNd4jTestWithBackends {
|
public class DeconvTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
|
@ -56,7 +57,7 @@ public class DeconvTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void compareKeras(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void compareKeras(Nd4jBackend backend) throws Exception {
|
||||||
File newFolder = testDir.toFile();
|
File newFolder = testDir.toFile();
|
||||||
new ClassPathResource("keras/deconv/").copyDirectory(newFolder);
|
new ClassPathResource("keras/deconv/").copyDirectory(newFolder);
|
||||||
|
|
||||||
|
|
|
@ -99,7 +99,8 @@ public class SpecialTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScalarShuffle1(Nd4jBackend backend) {
|
public void testScalarShuffle1(Nd4jBackend backend) {
|
||||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||||
List<DataSet> listData = new ArrayList<>();
|
List<DataSet> listData = new ArrayList<>();
|
||||||
|
|
|
@ -195,7 +195,8 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp, arrayX);
|
assertEquals(exp, arrayX);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInplaceOp1(Nd4jBackend backend) {
|
public void testInplaceOp1(Nd4jBackend backend) {
|
||||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||||
val arrayX = Nd4j.create(10, 10);
|
val arrayX = Nd4j.create(10, 10);
|
||||||
|
|
|
@ -41,10 +41,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
|
public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBalance(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testBalance(Nd4jBackend backend) throws Exception {
|
||||||
DataSetIterator iterator = new IrisDataSetIterator(10, 150);
|
DataSetIterator iterator = new IrisDataSetIterator(10, 150);
|
||||||
|
|
||||||
File minibatches = new File(testDir.toFile(),"mini-batch-dir");
|
File minibatches = new File(testDir.toFile(),"mini-batch-dir");
|
||||||
|
@ -62,7 +63,7 @@ public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMiniBatchBalanced(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testMiniBatchBalanced(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
int miniBatchSize = 100;
|
int miniBatchSize = 100;
|
||||||
DataSetIterator iterator = new IrisDataSetIterator(miniBatchSize, 150);
|
DataSetIterator iterator = new IrisDataSetIterator(miniBatchSize, 150);
|
||||||
|
|
|
@ -52,7 +52,9 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
|
||||||
|
|
||||||
public class DataSetTest extends BaseNd4jTestWithBackends {
|
public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@TempDir Path testDir;
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testViewIterator(Nd4jBackend backend) {
|
public void testViewIterator(Nd4jBackend backend) {
|
||||||
DataSetIterator iter = new ViewIterator(new IrisDataSetIterator(150, 150).next(), 10);
|
DataSetIterator iter = new ViewIterator(new IrisDataSetIterator(150, 150).next(), 10);
|
||||||
|
@ -106,9 +108,9 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSplitTestAndTrain (Nd4jBackend backend) {
|
public void testSplitTestAndTrain(Nd4jBackend backend) {
|
||||||
INDArray labels = FeatureUtil.toOutcomeMatrix(new int[] {0, 0, 0, 0, 0, 0, 0, 0}, 1);
|
INDArray labels = FeatureUtil.toOutcomeMatrix(new int[] {0, 0, 0, 0, 0, 0, 0, 0}, 1);
|
||||||
DataSet data = new DataSet(Nd4j.rand(8, 1), labels);
|
DataSet data = new DataSet(Nd4j.rand(8, 1), labels);
|
||||||
|
|
||||||
|
@ -116,7 +118,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(train.getTrain().getLabels().length(), 6);
|
assertEquals(train.getTrain().getLabels().length(), 6);
|
||||||
|
|
||||||
SplitTestAndTrain train2 = data.splitTestAndTrain(6, new Random(1));
|
SplitTestAndTrain train2 = data.splitTestAndTrain(6, new Random(1));
|
||||||
assertEquals(train.getTrain().getFeatures(), train2.getTrain().getFeatures(),getFailureMessage());
|
assertEquals(train.getTrain().getFeatures(), train2.getTrain().getFeatures(),getFailureMessage(backend));
|
||||||
|
|
||||||
DataSet x0 = new IrisDataSetIterator(150, 150).next();
|
DataSet x0 = new IrisDataSetIterator(150, 150).next();
|
||||||
SplitTestAndTrain testAndTrain = x0.splitTestAndTrain(10);
|
SplitTestAndTrain testAndTrain = x0.splitTestAndTrain(10);
|
||||||
|
@ -144,7 +146,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
SplitTestAndTrain testAndTrainRng = x2.splitTestAndTrain(10, rngHere);
|
SplitTestAndTrain testAndTrainRng = x2.splitTestAndTrain(10, rngHere);
|
||||||
|
|
||||||
assertArrayEquals(testAndTrainRng.getTrain().getFeatures().shape(),
|
assertArrayEquals(testAndTrainRng.getTrain().getFeatures().shape(),
|
||||||
testAndTrain.getTrain().getFeatures().shape());
|
testAndTrain.getTrain().getFeatures().shape());
|
||||||
assertEquals(testAndTrainRng.getTrain().getFeatures(), testAndTrain.getTrain().getFeatures());
|
assertEquals(testAndTrainRng.getTrain().getFeatures(), testAndTrain.getTrain().getFeatures());
|
||||||
assertEquals(testAndTrainRng.getTrain().getLabels(), testAndTrain.getTrain().getLabels());
|
assertEquals(testAndTrainRng.getTrain().getLabels(), testAndTrain.getTrain().getLabels());
|
||||||
|
|
||||||
|
@ -154,13 +156,13 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLabelCounts(Nd4jBackend backend) {
|
public void testLabelCounts(Nd4jBackend backend) {
|
||||||
DataSet x0 = new IrisDataSetIterator(150, 150).next();
|
DataSet x0 = new IrisDataSetIterator(150, 150).next();
|
||||||
assertEquals(0, x0.get(0).outcome(),getFailureMessage());
|
assertEquals(0, x0.get(0).outcome(),getFailureMessage(backend));
|
||||||
assertEquals( 0, x0.get(1).outcome(),getFailureMessage());
|
assertEquals( 0, x0.get(1).outcome(),getFailureMessage(backend));
|
||||||
assertEquals(2, x0.get(149).outcome(),getFailureMessage());
|
assertEquals(2, x0.get(149).outcome(),getFailureMessage(backend));
|
||||||
Map<Integer, Double> counts = x0.labelCounts();
|
Map<Integer, Double> counts = x0.labelCounts();
|
||||||
assertEquals(50, counts.get(0), 1e-1,getFailureMessage());
|
assertEquals(50, counts.get(0), 1e-1,getFailureMessage(backend));
|
||||||
assertEquals(50, counts.get(1), 1e-1,getFailureMessage());
|
assertEquals(50, counts.get(1), 1e-1,getFailureMessage(backend));
|
||||||
assertEquals(50, counts.get(2), 1e-1,getFailureMessage());
|
assertEquals(50, counts.get(2), 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -694,14 +696,14 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
INDArray expLabels3d = Nd4j.create(3, 3, 4);
|
INDArray expLabels3d = Nd4j.create(3, 3, 4);
|
||||||
expLabels3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)},
|
expLabels3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)},
|
||||||
l3d1);
|
l3d1);
|
||||||
expLabels3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
|
expLabels3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
|
||||||
NDArrayIndex.interval(0, 3)}, l3d2);
|
NDArrayIndex.interval(0, 3)}, l3d2);
|
||||||
INDArray expLM3d = Nd4j.create(3, 3, 4);
|
INDArray expLM3d = Nd4j.create(3, 3, 4);
|
||||||
expLM3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)},
|
expLM3d.put(new INDArrayIndex[] {interval(0,1), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)},
|
||||||
lm3d1);
|
lm3d1);
|
||||||
expLM3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
|
expLM3d.put(new INDArrayIndex[] {NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(),
|
||||||
NDArrayIndex.interval(0, 3)}, lm3d2);
|
NDArrayIndex.interval(0, 3)}, lm3d2);
|
||||||
|
|
||||||
|
|
||||||
DataSet merged3d = DataSet.merge(Arrays.asList(ds3d1, ds3d2));
|
DataSet merged3d = DataSet.merge(Arrays.asList(ds3d1, ds3d2));
|
||||||
|
@ -752,52 +754,52 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testShuffleNd(Nd4jBackend backend) {
|
public void testShuffleNd(Nd4jBackend backend) {
|
||||||
int numDims = 7;
|
int numDims = 7;
|
||||||
int nLabels = 3;
|
int nLabels = 3;
|
||||||
Random r = new Random();
|
Random r = new Random();
|
||||||
|
|
||||||
|
|
||||||
int[] shape = new int[numDims];
|
int[] shape = new int[numDims];
|
||||||
int entries = 1;
|
int entries = 1;
|
||||||
for (int i = 0; i < numDims; i++) {
|
for (int i = 0; i < numDims; i++) {
|
||||||
//randomly generating shapes bigger than 1
|
//randomly generating shapes bigger than 1
|
||||||
shape[i] = r.nextInt(4) + 2;
|
shape[i] = r.nextInt(4) + 2;
|
||||||
entries *= shape[i];
|
entries *= shape[i];
|
||||||
}
|
}
|
||||||
int labels = shape[0] * nLabels;
|
int labels = shape[0] * nLabels;
|
||||||
|
|
||||||
INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape);
|
INDArray ds_data = Nd4j.linspace(1, entries, entries, DataType.INT).reshape(shape);
|
||||||
INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels);
|
INDArray ds_labels = Nd4j.linspace(1, labels, labels, DataType.INT).reshape(shape[0], nLabels);
|
||||||
|
|
||||||
DataSet ds = new DataSet(ds_data, ds_labels);
|
DataSet ds = new DataSet(ds_data, ds_labels);
|
||||||
ds.shuffle();
|
ds.shuffle();
|
||||||
|
|
||||||
//Checking Nd dataset which is the data
|
//Checking Nd dataset which is the data
|
||||||
for (int dim = 1; dim < numDims; dim++) {
|
for (int dim = 1; dim < numDims; dim++) {
|
||||||
//get tensor along dimension - the order in every dimension but zero should be preserved
|
|
||||||
for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) {
|
|
||||||
//the difference between consecutive elements should be equal to the stride
|
|
||||||
for (int i = 0, j = 1; j < shape[dim]; i++, j++) {
|
|
||||||
int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i);
|
|
||||||
int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j);
|
|
||||||
int f_element_diff = f_next_element - f_element;
|
|
||||||
assertEquals(f_element_diff, ds_data.stride(dim));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//Checking 2d, features
|
|
||||||
int dim = 1;
|
|
||||||
//get tensor along dimension - the order in every dimension but zero should be preserved
|
//get tensor along dimension - the order in every dimension but zero should be preserved
|
||||||
for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) {
|
for (int tensorNum = 0; tensorNum < ds_data.tensorsAlongDimension(dim); tensorNum++) {
|
||||||
//the difference between consecutive elements should be equal to the stride
|
//the difference between consecutive elements should be equal to the stride
|
||||||
for (int i = 0, j = 1; j < nLabels; i++, j++) {
|
for (int i = 0, j = 1; j < shape[dim]; i++, j++) {
|
||||||
int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i);
|
int f_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(i);
|
||||||
int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j);
|
int f_next_element = ds.getFeatures().tensorAlongDimension(tensorNum, dim).getInt(j);
|
||||||
int l_element_diff = l_next_element - l_element;
|
int f_element_diff = f_next_element - f_element;
|
||||||
assertEquals(l_element_diff, ds_labels.stride(dim));
|
assertEquals(f_element_diff, ds_data.stride(dim));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//Checking 2d, features
|
||||||
|
int dim = 1;
|
||||||
|
//get tensor along dimension - the order in every dimension but zero should be preserved
|
||||||
|
for (int tensorNum = 0; tensorNum < ds_labels.tensorsAlongDimension(dim); tensorNum++) {
|
||||||
|
//the difference between consecutive elements should be equal to the stride
|
||||||
|
for (int i = 0, j = 1; j < nLabels; i++, j++) {
|
||||||
|
int l_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(i);
|
||||||
|
int l_next_element = ds.getLabels().tensorAlongDimension(tensorNum, dim).getInt(j);
|
||||||
|
int l_element_diff = l_next_element - l_element;
|
||||||
|
assertEquals(l_element_diff, ds_labels.stride(dim));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -936,9 +938,9 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
//Checking if the features and labels are equal
|
//Checking if the features and labels are equal
|
||||||
assertEquals(iDataSet.getFeatures(),
|
assertEquals(iDataSet.getFeatures(),
|
||||||
dsList.get(i).getFeatures().get(all(), all(), interval(0, minTSLength + i)));
|
dsList.get(i).getFeatures().get(all(), all(), interval(0, minTSLength + i)));
|
||||||
assertEquals(iDataSet.getLabels(),
|
assertEquals(iDataSet.getLabels(),
|
||||||
dsList.get(i).getLabels().get(all(), all(), interval(0, minTSLength + i)));
|
dsList.get(i).getLabels().get(all(), all(), interval(0, minTSLength + i)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -964,8 +966,8 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
for (boolean lMask : b) {
|
for (boolean lMask : b) {
|
||||||
|
|
||||||
DataSet ds = new DataSet((features ? f : null),
|
DataSet ds = new DataSet((features ? f : null),
|
||||||
(labels ? (labelsSameAsFeatures ? f : l) : null), (fMask ? fm : null),
|
(labels ? (labelsSameAsFeatures ? f : l) : null), (fMask ? fm : null),
|
||||||
(lMask ? lm : null));
|
(lMask ? lm : null));
|
||||||
|
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||||
DataOutputStream dos = new DataOutputStream(baos);
|
DataOutputStream dos = new DataOutputStream(baos);
|
||||||
|
@ -1009,7 +1011,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
boolean lMask = true;
|
boolean lMask = true;
|
||||||
|
|
||||||
DataSet ds = new DataSet((features ? f : null), (labels ? (labelsSameAsFeatures ? f : l) : null),
|
DataSet ds = new DataSet((features ? f : null), (labels ? (labelsSameAsFeatures ? f : l) : null),
|
||||||
(fMask ? fm : null), (lMask ? lm : null));
|
(fMask ? fm : null), (lMask ? lm : null));
|
||||||
|
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||||
DataOutputStream dos = new DataOutputStream(baos);
|
DataOutputStream dos = new DataOutputStream(baos);
|
||||||
|
@ -1098,7 +1100,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
|
public void testDataSetMetaDataSerialization(Nd4jBackend backend) throws IOException {
|
||||||
|
|
||||||
for(boolean withMeta : new boolean[]{false, true}) {
|
for(boolean withMeta : new boolean[]{false, true}) {
|
||||||
// create simple data set with meta data object
|
// create simple data set with meta data object
|
||||||
|
@ -1129,7 +1131,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMultiDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend nd4jBackend) throws IOException {
|
public void testMultiDataSetMetaDataSerialization(Nd4jBackend nd4jBackend) throws IOException {
|
||||||
|
|
||||||
for(boolean withMeta : new boolean[]{false, true}) {
|
for(boolean withMeta : new boolean[]{false, true}) {
|
||||||
// create simple data set with meta data object
|
// create simple data set with meta data object
|
||||||
|
|
|
@ -106,7 +106,8 @@ public class KFoldIteratorTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void checkCornerCaseException(Nd4jBackend backend) {
|
public void checkCornerCaseException(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
DataSet allData = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1),
|
DataSet allData = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1),
|
||||||
|
|
|
@ -21,27 +21,25 @@
|
||||||
package org.nd4j.linalg.dataset;
|
package org.nd4j.linalg.dataset;
|
||||||
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTestWithBackends {
|
public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMiniBatches(@TempDir Path testDir) throws Exception {
|
public void testMiniBatches(Nd4jBackend backend) throws Exception {
|
||||||
DataSet load = new IrisDataSetIterator(150, 150).next();
|
DataSet load = new IrisDataSetIterator(150, 150).next();
|
||||||
final MiniBatchFileDataSetIterator iter = new MiniBatchFileDataSetIterator(load, 10, false, testDir.toFile());
|
final MiniBatchFileDataSetIterator iter = new MiniBatchFileDataSetIterator(load, 10, false, testDir.toFile());
|
||||||
while (iter.hasNext())
|
while (iter.hasNext())
|
||||||
|
|
|
@ -39,8 +39,7 @@ public class CompositeDataSetPreProcessorTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_preConditionsIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
public void when_preConditionsIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
||||||
assertThrows(NullPointerException.class,() -> {
|
assertThrows(NullPointerException.class,() -> {
|
||||||
|
|
|
@ -41,8 +41,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_originalHeightIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_originalHeightIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
|
@ -51,8 +50,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_originalWidthIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_originalWidthIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
|
@ -61,8 +59,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_yStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_yStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
|
@ -71,8 +68,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_xStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_xStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
|
@ -81,8 +77,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
|
@ -91,8 +86,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
|
@ -101,8 +95,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
|
@ -111,8 +104,7 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
||||||
// Assemble
|
// Assemble
|
||||||
|
|
|
@ -39,7 +39,8 @@ public class PermuteDataSetPreProcessorTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
||||||
assertThrows(NullPointerException.class,() -> {
|
assertThrows(NullPointerException.class,() -> {
|
||||||
// Assemble
|
// Assemble
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.dataset.api.preprocessor;
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||||
|
@ -39,7 +38,8 @@ public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTestWithBacke
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
||||||
assertThrows(NullPointerException.class,() -> {
|
assertThrows(NullPointerException.class,() -> {
|
||||||
// Assemble
|
// Assemble
|
||||||
|
|
|
@ -139,7 +139,7 @@ public class Nd4jTest extends BaseNd4jTestWithBackends {
|
||||||
INDArray actualResult = data.mean(0);
|
INDArray actualResult = data.mean(0);
|
||||||
INDArray expectedResult = Nd4j.create(new double[] {3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6.,
|
INDArray expectedResult = Nd4j.create(new double[] {3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6.,
|
||||||
6., 3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6.}, new int[] {2, 4, 4});
|
6., 3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6.}, new int[] {2, 4, 4});
|
||||||
assertEquals(expectedResult, actualResult,getFailureMessage());
|
assertEquals(expectedResult, actualResult,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -154,7 +154,7 @@ public class Nd4jTest extends BaseNd4jTestWithBackends {
|
||||||
INDArray actualResult = data.var(false, 0);
|
INDArray actualResult = data.var(false, 0);
|
||||||
INDArray expectedResult = Nd4j.create(new double[] {1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4.,
|
INDArray expectedResult = Nd4j.create(new double[] {1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4.,
|
||||||
4., 1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4.}, new long[] {2, 4, 4});
|
4., 1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4.}, new long[] {2, 4, 4});
|
||||||
assertEquals(expectedResult, actualResult,getFailureMessage());
|
assertEquals(expectedResult, actualResult,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
|
|
@ -83,8 +83,7 @@ public class CloseableTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAccessException_1(Nd4jBackend backend) {
|
public void testAccessException_1(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
|
@ -96,8 +95,7 @@ public class CloseableTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAccessException_2(Nd4jBackend backend) {
|
public void testAccessException_2(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
|
|
|
@ -384,7 +384,9 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp, arrayZ);
|
assertEquals(exp, arrayZ);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
|
||||||
public void testTypesValidation_1(Nd4jBackend backend) {
|
public void testTypesValidation_1(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.LONG);
|
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.LONG);
|
||||||
|
@ -397,7 +399,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTypesValidation_2(Nd4jBackend backend) {
|
public void testTypesValidation_2(Nd4jBackend backend) {
|
||||||
assertThrows(RuntimeException.class,() -> {
|
assertThrows(RuntimeException.class,() -> {
|
||||||
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
||||||
|
@ -412,7 +415,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTypesValidation_3(Nd4jBackend backend) {
|
public void testTypesValidation_3(Nd4jBackend backend) {
|
||||||
assertThrows(RuntimeException.class,() -> {
|
assertThrows(RuntimeException.class,() -> {
|
||||||
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
||||||
|
@ -422,6 +426,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTypesValidation_4(Nd4jBackend backend) {
|
public void testTypesValidation_4(Nd4jBackend backend) {
|
||||||
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
||||||
val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.DOUBLE);
|
val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.DOUBLE);
|
||||||
|
@ -485,7 +491,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBoolFloatCast2(){
|
public void testBoolFloatCast2(Nd4jBackend backend){
|
||||||
val first = Nd4j.zeros(DataType.FLOAT, 3, 5000);
|
val first = Nd4j.zeros(DataType.FLOAT, 3, 5000);
|
||||||
INDArray asBool = first.castTo(DataType.BOOL);
|
INDArray asBool = first.castTo(DataType.BOOL);
|
||||||
INDArray not = Transforms.not(asBool); //
|
INDArray not = Transforms.not(asBool); //
|
||||||
|
@ -516,7 +522,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAssignScalarSimple(){
|
public void testAssignScalarSimple(Nd4jBackend backend){
|
||||||
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
||||||
INDArray arr = Nd4j.scalar(dt, 10.0);
|
INDArray arr = Nd4j.scalar(dt, 10.0);
|
||||||
arr.assign(2.0);
|
arr.assign(2.0);
|
||||||
|
@ -526,7 +532,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSimple(){
|
public void testSimple(Nd4jBackend backend){
|
||||||
Nd4j.create(1);
|
Nd4j.create(1);
|
||||||
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, DataType.LONG}) {
|
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, DataType.LONG}) {
|
||||||
// System.out.println("----- " + dt + " -----");
|
// System.out.println("----- " + dt + " -----");
|
||||||
|
@ -551,7 +557,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testWorkspaceBool(){
|
public void testWorkspaceBool(Nd4jBackend backend){
|
||||||
val conf = WorkspaceConfiguration.builder().minSize(10 * 1024 * 1024)
|
val conf = WorkspaceConfiguration.builder().minSize(10 * 1024 * 1024)
|
||||||
.overallocationLimit(1.0).policyAllocation(AllocationPolicy.OVERALLOCATE)
|
.overallocationLimit(1.0).policyAllocation(AllocationPolicy.OVERALLOCATE)
|
||||||
.policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL)
|
.policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL)
|
||||||
|
@ -559,7 +565,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(conf, "WS");
|
val ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(conf, "WS");
|
||||||
|
|
||||||
for( int i=0; i<10; i++ ) {
|
for( int i = 0; i < 10; i++ ) {
|
||||||
try (val workspace = (Nd4jWorkspace)ws.notifyScopeEntered() ) {
|
try (val workspace = (Nd4jWorkspace)ws.notifyScopeEntered() ) {
|
||||||
val bool = Nd4j.create(DataType.BOOL, 1, 10);
|
val bool = Nd4j.create(DataType.BOOL, 1, 10);
|
||||||
val dbl = Nd4j.create(DataType.DOUBLE, 1, 10);
|
val dbl = Nd4j.create(DataType.DOUBLE, 1, 10);
|
||||||
|
@ -574,8 +580,9 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
public void testArrayCreationFromPointer(Nd4jBackend backend) {
|
public void testArrayCreationFromPointer(Nd4jBackend backend) {
|
||||||
val source = Nd4j.create(new double[]{1, 2, 3, 4, 5});
|
val source = Nd4j.create(new double[]{1, 2, 3, 4, 5});
|
||||||
|
|
||||||
|
|
|
@ -40,13 +40,13 @@ public class NativeBlasTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void setUp(Nd4jBackend backend) {
|
public void setUp() {
|
||||||
Nd4j.getExecutioner().enableDebugMode(true);
|
Nd4j.getExecutioner().enableDebugMode(true);
|
||||||
Nd4j.getExecutioner().enableVerboseMode(true);
|
Nd4j.getExecutioner().enableVerboseMode(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
public void setDown(Nd4jBackend backend) {
|
public void setDown() {
|
||||||
Nd4j.getExecutioner().enableDebugMode(false);
|
Nd4j.getExecutioner().enableDebugMode(false);
|
||||||
Nd4j.getExecutioner().enableVerboseMode(false);
|
Nd4j.getExecutioner().enableVerboseMode(false);
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,18 +77,18 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
|
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
|
||||||
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
|
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
|
||||||
double sim = Transforms.cosineSim(vec1, vec2);
|
double sim = Transforms.cosineSim(vec1, vec2);
|
||||||
assertEquals( 1, sim, 1e-1,getFailureMessage());
|
assertEquals( 1, sim, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCosineDistance(){
|
public void testCosineDistance(Nd4jBackend backend){
|
||||||
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3});
|
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3});
|
||||||
INDArray vec2 = Nd4j.create(new float[] {3, 5, 7});
|
INDArray vec2 = Nd4j.create(new float[] {3, 5, 7});
|
||||||
// 1-17*sqrt(2/581)
|
// 1-17*sqrt(2/581)
|
||||||
double distance = Transforms.cosineDistance(vec1, vec2);
|
double distance = Transforms.cosineDistance(vec1, vec2);
|
||||||
assertEquals(0.0025851, distance, 1e-7,getFailureMessage());
|
assertEquals(0.0025851, distance, 1e-7,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -97,7 +97,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
INDArray arr = Nd4j.create(new double[] {55, 55});
|
INDArray arr = Nd4j.create(new double[] {55, 55});
|
||||||
INDArray arr2 = Nd4j.create(new double[] {60, 60});
|
INDArray arr2 = Nd4j.create(new double[] {60, 60});
|
||||||
double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).z().getDouble(0);
|
double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).z().getDouble(0);
|
||||||
assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage());
|
assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -137,7 +137,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi();
|
INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi();
|
||||||
INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6);
|
INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6);
|
||||||
Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1));
|
Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1));
|
||||||
assertEquals(scalarMax, postMax,getFailureMessage());
|
assertEquals(scalarMax, postMax,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -147,14 +147,14 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1));
|
Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1));
|
||||||
for (int i = 0; i < linspace.length(); i++) {
|
for (int i = 0; i < linspace.length(); i++) {
|
||||||
double val = linspace.getDouble(i);
|
double val = linspace.getDouble(i);
|
||||||
assertTrue( val >= 0 && val <= 1,getFailureMessage());
|
assertTrue( val >= 0 && val <= 1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray linspace2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
INDArray linspace2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||||
Nd4j.getExecutioner().exec(new SetRange(linspace2, 2, 4));
|
Nd4j.getExecutioner().exec(new SetRange(linspace2, 2, 4));
|
||||||
for (int i = 0; i < linspace2.length(); i++) {
|
for (int i = 0; i < linspace2.length(); i++) {
|
||||||
double val = linspace2.getDouble(i);
|
double val = linspace2.getDouble(i);
|
||||||
assertTrue( val >= 2 && val <= 4,getFailureMessage());
|
assertTrue( val >= 2 && val <= 4,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,7 +163,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
public void testNormMax(Nd4jBackend backend) {
|
public void testNormMax(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).z().getDouble(0);
|
double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).z().getDouble(0);
|
||||||
assertEquals(4, normMax, 1e-1,getFailureMessage());
|
assertEquals(4, normMax, 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -187,7 +187,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
public void testNorm2(Nd4jBackend backend) {
|
public void testNorm2(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).z().getDouble(0);
|
double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).z().getDouble(0);
|
||||||
assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage());
|
assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -198,7 +198,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
INDArray xDup = x.dup();
|
INDArray xDup = x.dup();
|
||||||
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
|
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
|
||||||
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
|
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
|
||||||
assertEquals(solution, x,getFailureMessage());
|
assertEquals(solution, x,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -221,13 +221,13 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
INDArray xDup = x.dup();
|
INDArray xDup = x.dup();
|
||||||
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
|
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
|
||||||
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
|
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
|
||||||
assertEquals(solution, x,getFailureMessage());
|
assertEquals(solution, x,getFailureMessage(backend));
|
||||||
Sum acc = new Sum(x.dup());
|
Sum acc = new Sum(x.dup());
|
||||||
opExecutioner.exec(acc);
|
opExecutioner.exec(acc);
|
||||||
assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
Prod prod = new Prod(x.dup());
|
Prod prod = new Prod(x.dup());
|
||||||
opExecutioner.exec(prod);
|
opExecutioner.exec(prod);
|
||||||
assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -275,7 +275,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
Variance variance = new Variance(x.dup(), true);
|
Variance variance = new Variance(x.dup(), true);
|
||||||
opExecutioner.exec(variance);
|
opExecutioner.exec(variance);
|
||||||
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -284,14 +284,14 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIamax(Nd4jBackend backend) {
|
public void testIamax(Nd4jBackend backend) {
|
||||||
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||||
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage());
|
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIamax2(Nd4jBackend backend) {
|
public void testIamax2(Nd4jBackend backend) {
|
||||||
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||||
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage());
|
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage(backend));
|
||||||
val op = new ArgAmax(linspace);
|
val op = new ArgAmax(linspace);
|
||||||
|
|
||||||
int iamax = Nd4j.getExecutioner().exec(op)[0].getInt(0);
|
int iamax = Nd4j.getExecutioner().exec(op)[0].getInt(0);
|
||||||
|
@ -307,11 +307,11 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
Mean mean = new Mean(x);
|
Mean mean = new Mean(x);
|
||||||
opExecutioner.exec(mean);
|
opExecutioner.exec(mean);
|
||||||
assertEquals( 3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals( 3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
Variance variance = new Variance(x.dup(), true);
|
Variance variance = new Variance(x.dup(), true);
|
||||||
opExecutioner.exec(variance);
|
opExecutioner.exec(variance);
|
||||||
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -321,7 +321,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
|
val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
|
||||||
val softMax = new SoftMax(arr);
|
val softMax = new SoftMax(arr);
|
||||||
opExecutioner.exec((CustomOp) softMax);
|
opExecutioner.exec((CustomOp) softMax);
|
||||||
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -332,7 +332,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
Pow pow = new Pow(oneThroughSix, 2);
|
Pow pow = new Pow(oneThroughSix, 2);
|
||||||
Nd4j.getExecutioner().exec(pow);
|
Nd4j.getExecutioner().exec(pow);
|
||||||
INDArray answer = Nd4j.create(new double[] {1, 4, 9, 16, 25, 36});
|
INDArray answer = Nd4j.create(new double[] {1, 4, 9, 16, 25, 36});
|
||||||
assertEquals(answer, pow.z(),getFailureMessage());
|
assertEquals(answer, pow.z(),getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -384,7 +384,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
Log log = new Log(slice);
|
Log log = new Log(slice);
|
||||||
opExecutioner.exec(log);
|
opExecutioner.exec(log);
|
||||||
INDArray assertion = Nd4j.create(new double[] {0., 1.09861229, 1.60943791});
|
INDArray assertion = Nd4j.create(new double[] {0., 1.09861229, 1.60943791});
|
||||||
assertEquals(assertion, slice,getFailureMessage());
|
assertEquals(assertion, slice,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -572,7 +572,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
expected[i] = (float) Math.exp(slice.getDouble(i));
|
expected[i] = (float) Math.exp(slice.getDouble(i));
|
||||||
Exp exp = new Exp(slice);
|
Exp exp = new Exp(slice);
|
||||||
opExecutioner.exec(exp);
|
opExecutioner.exec(exp);
|
||||||
assertEquals( Nd4j.create(expected), slice,getFailureMessage());
|
assertEquals( Nd4j.create(expected), slice,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -582,7 +582,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
|
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
|
||||||
val softMax = new SoftMax(arr);
|
val softMax = new SoftMax(arr);
|
||||||
opExecutioner.exec((CustomOp) softMax);
|
opExecutioner.exec((CustomOp) softMax);
|
||||||
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue