Unify nd4j test profiles, get rid of old modules, fix more parameter issues with junit 5 tests
parent
e0077c38a9
commit
ad4f47096c
|
@ -31,7 +31,7 @@ jobs:
|
||||||
protoc --version
|
protoc --version
|
||||||
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
windows-x86_64:
|
windows-x86_64:
|
||||||
runs-on: windows-2019
|
runs-on: windows-2019
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,5 +60,5 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ jobs:
|
||||||
protoc --version
|
protoc --version
|
||||||
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
windows-x86_64:
|
windows-x86_64:
|
||||||
runs-on: windows-2019
|
runs-on: windows-2019
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,5 +60,5 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
|
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
||||||
|
|
||||||
|
|
|
@ -34,5 +34,5 @@ jobs:
|
||||||
cmake --version
|
cmake --version
|
||||||
protoc --version
|
protoc --version
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Ptest-nd4j-native --also-make clean test
|
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Pnd4j-tests-cpu --also-make clean test
|
||||||
|
|
||||||
|
|
|
@ -109,10 +109,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -30,6 +30,8 @@ import org.datavec.api.writable.Writable;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import org.nd4j.common.loader.FileBatch;
|
import org.nd4j.common.loader.FileBatch;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -40,13 +42,16 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
@DisplayName("File Batch Record Reader Test")
|
@DisplayName("File Batch Record Reader Test")
|
||||||
class FileBatchRecordReaderTest extends BaseND4JTest {
|
public class FileBatchRecordReaderTest extends BaseND4JTest {
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@DisplayName("Test Csv")
|
@DisplayName("Test Csv")
|
||||||
void testCsv(@TempDir Path testDir) throws Exception {
|
void testCsv(Nd4jBackend backend) throws Exception {
|
||||||
// This is an unrealistic use case - one line/record per CSV
|
// This is an unrealistic use case - one line/record per CSV
|
||||||
File baseDir = testDir.toFile();
|
File baseDir = testDir.toFile();
|
||||||
List<File> fileList = new ArrayList<>();
|
List<File> fileList = new ArrayList<>();
|
||||||
|
@ -75,9 +80,10 @@ class FileBatchRecordReaderTest extends BaseND4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@DisplayName("Test Csv Sequence")
|
@DisplayName("Test Csv Sequence")
|
||||||
void testCsvSequence(@TempDir Path testDir) throws Exception {
|
void testCsvSequence(Nd4jBackend backend) throws Exception {
|
||||||
// CSV sequence - 3 lines per file, 10 files
|
// CSV sequence - 3 lines per file, 10 files
|
||||||
File baseDir = testDir.toFile();
|
File baseDir = testDir.toFile();
|
||||||
List<File> fileList = new ArrayList<>();
|
List<File> fileList = new ArrayList<>();
|
||||||
|
|
|
@ -21,7 +21,6 @@ package org.datavec.api.transform.ops;
|
||||||
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.ExpectedException;
|
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
|
@ -60,10 +60,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -119,10 +119,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -59,10 +59,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -57,10 +57,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -65,10 +65,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -61,25 +61,18 @@
|
||||||
<artifactId>nd4j-common</artifactId>
|
<artifactId>nd4j-common</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.datavec</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>datavec-geo</artifactId>
|
<artifactId>python4j-numpy</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.datavec</groupId>
|
|
||||||
<artifactId>datavec-python</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -29,7 +29,6 @@ import org.datavec.api.transform.reduce.Reducer;
|
||||||
import org.datavec.api.transform.schema.Schema;
|
import org.datavec.api.transform.schema.Schema;
|
||||||
import org.datavec.api.transform.schema.SequenceSchema;
|
import org.datavec.api.transform.schema.SequenceSchema;
|
||||||
import org.datavec.api.writable.*;
|
import org.datavec.api.writable.*;
|
||||||
import org.datavec.python.PythonTransform;
|
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
@ -39,7 +38,6 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
|
||||||
import static java.time.Duration.ofMillis;
|
import static java.time.Duration.ofMillis;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTimeout;
|
import static org.junit.jupiter.api.Assertions.assertTimeout;
|
||||||
|
|
||||||
|
@ -166,37 +164,8 @@ class ExecutionTest {
|
||||||
List<List<Writable>> out = outRdd;
|
List<List<Writable>> out = outRdd;
|
||||||
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
|
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
|
||||||
out = new ArrayList<>(out);
|
out = new ArrayList<>(out);
|
||||||
Collections.sort(out, new Comparator<List<Writable>>() {
|
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
|
||||||
|
|
||||||
@Override
|
|
||||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
|
||||||
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
assertEquals(expOut, out);
|
assertEquals(expOut, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
|
|
||||||
@DisplayName("Test Python Execution Ndarray")
|
|
||||||
void testPythonExecutionNdarray() {
|
|
||||||
assertTimeout(ofMillis(60000), () -> {
|
|
||||||
Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build();
|
|
||||||
TransformProcess transformProcess = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build();
|
|
||||||
List<List<Writable>> functions = new ArrayList<>();
|
|
||||||
List<Writable> firstRow = new ArrayList<>();
|
|
||||||
INDArray firstArr = Nd4j.linspace(1, 4, 4);
|
|
||||||
INDArray secondArr = Nd4j.linspace(1, 4, 4);
|
|
||||||
firstRow.add(new NDArrayWritable(firstArr));
|
|
||||||
firstRow.add(new NDArrayWritable(secondArr));
|
|
||||||
functions.add(firstRow);
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
|
|
||||||
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
|
|
||||||
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
|
|
||||||
INDArray expected = Transforms.sin(firstArr);
|
|
||||||
INDArray secondExpected = Transforms.cos(secondArr);
|
|
||||||
assertEquals(expected, firstResult);
|
|
||||||
assertEquals(secondExpected, secondResult);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,155 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.datavec.local.transforms.transform;
|
|
||||||
|
|
||||||
import org.datavec.api.transform.ColumnType;
|
|
||||||
import org.datavec.api.transform.Transform;
|
|
||||||
import org.datavec.api.transform.geo.LocationType;
|
|
||||||
import org.datavec.api.transform.schema.Schema;
|
|
||||||
import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform;
|
|
||||||
import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform;
|
|
||||||
import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform;
|
|
||||||
import org.datavec.api.writable.DoubleWritable;
|
|
||||||
import org.datavec.api.writable.Text;
|
|
||||||
import org.datavec.api.writable.Writable;
|
|
||||||
import org.junit.AfterClass;
|
|
||||||
import org.junit.BeforeClass;
|
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
|
||||||
|
|
||||||
import java.io.*;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author saudet
|
|
||||||
*/
|
|
||||||
public class TestGeoTransforms {
|
|
||||||
|
|
||||||
@BeforeAll
|
|
||||||
public static void beforeClass() throws Exception {
|
|
||||||
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
|
|
||||||
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
|
|
||||||
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath());
|
|
||||||
}
|
|
||||||
|
|
||||||
@AfterClass
|
|
||||||
public static void afterClass(){
|
|
||||||
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, "");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testCoordinatesDistanceTransform() throws Exception {
|
|
||||||
Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
|
|
||||||
transform.setInputSchema(schema);
|
|
||||||
|
|
||||||
Schema out = transform.transform(schema);
|
|
||||||
assertEquals(4, out.numColumns());
|
|
||||||
assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
|
|
||||||
assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
|
|
||||||
out.getColumnTypes());
|
|
||||||
|
|
||||||
assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
|
|
||||||
transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
|
|
||||||
assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
|
|
||||||
new DoubleWritable(Math.sqrt(160))),
|
|
||||||
transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
|
|
||||||
new Text("10|5"))));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testIPAddressToCoordinatesTransform() throws Exception {
|
|
||||||
Schema schema = new Schema.Builder().addColumnString("column").build();
|
|
||||||
|
|
||||||
Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER");
|
|
||||||
transform.setInputSchema(schema);
|
|
||||||
|
|
||||||
Schema out = transform.transform(schema);
|
|
||||||
|
|
||||||
assertEquals(1, out.getColumnMetaData().size());
|
|
||||||
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
|
|
||||||
|
|
||||||
String in = "81.2.69.160";
|
|
||||||
double latitude = 51.5142;
|
|
||||||
double longitude = -0.0931;
|
|
||||||
|
|
||||||
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
|
|
||||||
assertEquals(1, writables.size());
|
|
||||||
String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
|
|
||||||
assertEquals(2, coordinates.length);
|
|
||||||
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
|
|
||||||
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
|
|
||||||
|
|
||||||
//Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/
|
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
|
||||||
ObjectOutputStream oos = new ObjectOutputStream(baos);
|
|
||||||
oos.writeObject(transform);
|
|
||||||
|
|
||||||
byte[] bytes = baos.toByteArray();
|
|
||||||
|
|
||||||
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
|
||||||
ObjectInputStream ois = new ObjectInputStream(bais);
|
|
||||||
|
|
||||||
Transform deserialized = (Transform) ois.readObject();
|
|
||||||
writables = deserialized.map(Collections.singletonList((Writable) new Text(in)));
|
|
||||||
assertEquals(1, writables.size());
|
|
||||||
coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
|
|
||||||
//System.out.println(Arrays.toString(coordinates));
|
|
||||||
assertEquals(2, coordinates.length);
|
|
||||||
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
|
|
||||||
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testIPAddressToLocationTransform() throws Exception {
|
|
||||||
Schema schema = new Schema.Builder().addColumnString("column").build();
|
|
||||||
LocationType[] locationTypes = LocationType.values();
|
|
||||||
String in = "81.2.69.160";
|
|
||||||
String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167",
|
|
||||||
"51.5142:-0.0931", "", "England", "6269131"}; //Note: no postcode in this test DB for this record
|
|
||||||
|
|
||||||
for (int i = 0; i < locationTypes.length; i++) {
|
|
||||||
LocationType locationType = locationTypes[i];
|
|
||||||
String location = locations[i];
|
|
||||||
|
|
||||||
Transform transform = new IPAddressToLocationTransform("column", locationType);
|
|
||||||
transform.setInputSchema(schema);
|
|
||||||
|
|
||||||
Schema out = transform.transform(schema);
|
|
||||||
|
|
||||||
assertEquals(1, out.getColumnMetaData().size());
|
|
||||||
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
|
|
||||||
|
|
||||||
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
|
|
||||||
assertEquals(1, writables.size());
|
|
||||||
assertEquals(location, writables.get(0).toString());
|
|
||||||
//System.out.println(location);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,386 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.datavec.local.transforms.transform;
|
|
||||||
|
|
||||||
import org.datavec.api.transform.TransformProcess;
|
|
||||||
import org.datavec.api.transform.condition.Condition;
|
|
||||||
import org.datavec.api.transform.filter.ConditionFilter;
|
|
||||||
import org.datavec.api.transform.filter.Filter;
|
|
||||||
import org.datavec.api.transform.schema.Schema;
|
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
|
||||||
|
|
||||||
import org.datavec.api.writable.*;
|
|
||||||
import org.datavec.python.PythonCondition;
|
|
||||||
import org.datavec.python.PythonTransform;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.junit.jupiter.api.Timeout;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import javax.annotation.concurrent.NotThreadSafe;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
|
|
||||||
import static org.datavec.api.transform.schema.Schema.Builder;
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
|
||||||
|
|
||||||
@NotThreadSafe
|
|
||||||
public class TestPythonTransformProcess {
|
|
||||||
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
public void testStringConcat() throws Exception{
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnString("col1")
|
|
||||||
.addColumnString("col2");
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnString("col3");
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col3 = col1 + col2";
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build()
|
|
||||||
).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals((outputs.get(0)).toString(), "Hello ");
|
|
||||||
assertEquals((outputs.get(1)).toString(), "World!");
|
|
||||||
assertEquals((outputs.get(2)).toString(), "Hello World!");
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
@Timeout(60000L)
|
|
||||||
public void testMixedTypes() throws Exception {
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnInteger("col1")
|
|
||||||
.addColumnFloat("col2")
|
|
||||||
.addColumnString("col3")
|
|
||||||
.addColumnDouble("col4");
|
|
||||||
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnInteger("col5");
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.inputSchema(initialSchema)
|
|
||||||
.build() ).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList(new IntWritable(10),
|
|
||||||
new FloatWritable(3.5f),
|
|
||||||
new Text("5"),
|
|
||||||
new DoubleWritable(2.0)
|
|
||||||
);
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals(((LongWritable)outputs.get(4)).get(), 36);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
@Timeout(60000L)
|
|
||||||
public void testNDArray() throws Exception {
|
|
||||||
long[] shape = new long[]{3, 2};
|
|
||||||
INDArray arr1 = Nd4j.rand(shape);
|
|
||||||
INDArray arr2 = Nd4j.rand(shape);
|
|
||||||
|
|
||||||
INDArray expectedOutput = arr1.add(arr2);
|
|
||||||
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnNDArray("col1", shape)
|
|
||||||
.addColumnNDArray("col2", shape);
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnNDArray("col3", shape);
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col3 = col1 + col2";
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build() ).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new NDArrayWritable(arr1),
|
|
||||||
new NDArrayWritable(arr2)
|
|
||||||
);
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
|
||||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
|
||||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
@Timeout(60000L)
|
|
||||||
public void testNDArray2() throws Exception {
|
|
||||||
long[] shape = new long[]{3, 2};
|
|
||||||
INDArray arr1 = Nd4j.rand(shape);
|
|
||||||
INDArray arr2 = Nd4j.rand(shape);
|
|
||||||
|
|
||||||
INDArray expectedOutput = arr1.add(arr2);
|
|
||||||
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnNDArray("col1", shape)
|
|
||||||
.addColumnNDArray("col2", shape);
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnNDArray("col3", shape);
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col3 = col1 + col2";
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build() ).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new NDArrayWritable(arr1),
|
|
||||||
new NDArrayWritable(arr2)
|
|
||||||
);
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
|
||||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
|
||||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
@Timeout(60000L)
|
|
||||||
public void testNDArrayMixed() throws Exception{
|
|
||||||
long[] shape = new long[]{3, 2};
|
|
||||||
INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
|
|
||||||
INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
|
|
||||||
INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
|
|
||||||
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnNDArray("col1", shape)
|
|
||||||
.addColumnNDArray("col2", shape);
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnNDArray("col3", shape);
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
String pythonCode = "col3 = col1 + col2";
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build()
|
|
||||||
).build();
|
|
||||||
|
|
||||||
List<Writable> inputs = Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new NDArrayWritable(arr1),
|
|
||||||
new NDArrayWritable(arr2)
|
|
||||||
);
|
|
||||||
|
|
||||||
List<Writable> outputs = tp.execute(inputs);
|
|
||||||
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
|
|
||||||
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
|
|
||||||
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
@Timeout(60000L)
|
|
||||||
public void testPythonFilter() {
|
|
||||||
Schema schema = new Builder().addColumnInteger("column").build();
|
|
||||||
|
|
||||||
Condition condition = new PythonCondition(
|
|
||||||
"f = lambda: column < 0"
|
|
||||||
);
|
|
||||||
|
|
||||||
condition.setInputSchema(schema);
|
|
||||||
|
|
||||||
Filter filter = new ConditionFilter(condition);
|
|
||||||
|
|
||||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
|
|
||||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
|
|
||||||
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
|
|
||||||
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
|
|
||||||
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test()
|
|
||||||
@Timeout(60000L)
|
|
||||||
public void testPythonFilterAndTransform() throws Exception {
|
|
||||||
Builder schemaBuilder = new Builder();
|
|
||||||
schemaBuilder
|
|
||||||
.addColumnInteger("col1")
|
|
||||||
.addColumnFloat("col2")
|
|
||||||
.addColumnString("col3")
|
|
||||||
.addColumnDouble("col4");
|
|
||||||
|
|
||||||
Schema initialSchema = schemaBuilder.build();
|
|
||||||
schemaBuilder.addColumnString("col6");
|
|
||||||
Schema finalSchema = schemaBuilder.build();
|
|
||||||
|
|
||||||
Condition condition = new PythonCondition(
|
|
||||||
"f = lambda: col1 < 0 and col2 > 10.0"
|
|
||||||
);
|
|
||||||
|
|
||||||
condition.setInputSchema(initialSchema);
|
|
||||||
|
|
||||||
Filter filter = new ConditionFilter(condition);
|
|
||||||
|
|
||||||
String pythonCode = "col6 = str(col1 + col2)";
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
|
|
||||||
PythonTransform.builder().code(pythonCode)
|
|
||||||
.outputSchema(finalSchema)
|
|
||||||
.build()
|
|
||||||
).filter(
|
|
||||||
filter
|
|
||||||
).build();
|
|
||||||
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
|
||||||
inputs.add(
|
|
||||||
Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new IntWritable(5),
|
|
||||||
new FloatWritable(3.0f),
|
|
||||||
new Text("abcd"),
|
|
||||||
new DoubleWritable(2.1))
|
|
||||||
);
|
|
||||||
inputs.add(
|
|
||||||
Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new IntWritable(-3),
|
|
||||||
new FloatWritable(3.0f),
|
|
||||||
new Text("abcd"),
|
|
||||||
new DoubleWritable(2.1))
|
|
||||||
);
|
|
||||||
inputs.add(
|
|
||||||
Arrays.asList(
|
|
||||||
(Writable)
|
|
||||||
new IntWritable(5),
|
|
||||||
new FloatWritable(11.2f),
|
|
||||||
new Text("abcd"),
|
|
||||||
new DoubleWritable(2.1))
|
|
||||||
);
|
|
||||||
|
|
||||||
LocalTransformExecutor.execute(inputs,tp);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testPythonTransformNoOutputSpecified() throws Exception {
|
|
||||||
PythonTransform pythonTransform = PythonTransform.builder()
|
|
||||||
.code("a += 2; b = 'hello world'")
|
|
||||||
.returnAllInputs(true)
|
|
||||||
.build();
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
|
||||||
inputs.add(Arrays.asList((Writable)new IntWritable(1)));
|
|
||||||
Schema inputSchema = new Builder()
|
|
||||||
.addColumnInteger("a")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
|
||||||
.transform(pythonTransform)
|
|
||||||
.build();
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
|
||||||
assertEquals(3,execute.get(0).get(0).toInt());
|
|
||||||
assertEquals("hello world",execute.get(0).get(1).toString());
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testNumpyTransform() {
|
|
||||||
PythonTransform pythonTransform = PythonTransform.builder()
|
|
||||||
.code("a += 2; b = 'hello world'")
|
|
||||||
.returnAllInputs(true)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
|
||||||
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
|
|
||||||
Schema inputSchema = new Builder()
|
|
||||||
.addColumnNDArray("a",new long[]{1,1})
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
|
||||||
.transform(pythonTransform)
|
|
||||||
.build();
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
|
||||||
assertFalse(execute.isEmpty());
|
|
||||||
assertNotNull(execute.get(0));
|
|
||||||
assertNotNull(execute.get(0).get(0));
|
|
||||||
assertNotNull(execute.get(0).get(1));
|
|
||||||
assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get());
|
|
||||||
assertEquals("hello world",execute.get(0).get(1).toString());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testWithSetupRun() throws Exception {
|
|
||||||
|
|
||||||
PythonTransform pythonTransform = PythonTransform.builder()
|
|
||||||
.code("five=None\n" +
|
|
||||||
"def setup():\n" +
|
|
||||||
" global five\n"+
|
|
||||||
" five = 5\n\n" +
|
|
||||||
"def run(a, b):\n" +
|
|
||||||
" c = a + b + five\n"+
|
|
||||||
" return {'c':c}\n\n")
|
|
||||||
.returnAllInputs(true)
|
|
||||||
.setupAndRun(true)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
List<List<Writable>> inputs = new ArrayList<>();
|
|
||||||
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)),
|
|
||||||
new NDArrayWritable(Nd4j.scalar(2).reshape(1,1))));
|
|
||||||
Schema inputSchema = new Builder()
|
|
||||||
.addColumnNDArray("a",new long[]{1,1})
|
|
||||||
.addColumnNDArray("b", new long[]{1, 1})
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(inputSchema)
|
|
||||||
.transform(pythonTransform)
|
|
||||||
.build();
|
|
||||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
|
|
||||||
assertFalse(execute.isEmpty());
|
|
||||||
assertNotNull(execute.get(0));
|
|
||||||
assertNotNull(execute.get(0).get(0));
|
|
||||||
assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -128,10 +128,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -92,6 +92,10 @@
|
||||||
<groupId>org.junit.jupiter</groupId>
|
<groupId>org.junit.jupiter</groupId>
|
||||||
<artifactId>junit-jupiter-api</artifactId>
|
<artifactId>junit-jupiter-api</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-params</artifactId>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.junit.vintage</groupId>
|
<groupId>org.junit.vintage</groupId>
|
||||||
<artifactId>junit-vintage-engine</artifactId>
|
<artifactId>junit-vintage-engine</artifactId>
|
||||||
|
@ -154,7 +158,7 @@
|
||||||
<skip>${skipTestResourceEnforcement}</skip>
|
<skip>${skipTestResourceEnforcement}</skip>
|
||||||
<rules>
|
<rules>
|
||||||
<requireActiveProfile>
|
<requireActiveProfile>
|
||||||
<profiles>test-nd4j-native,test-nd4j-cuda-11.0</profiles>
|
<profiles>nd4j-tests-cpu,nd4j-tests-cuda</profiles>
|
||||||
<all>false</all>
|
<all>false</all>
|
||||||
</requireActiveProfile>
|
</requireActiveProfile>
|
||||||
</rules>
|
</rules>
|
||||||
|
@ -163,23 +167,6 @@
|
||||||
</execution>
|
</execution>
|
||||||
</executions>
|
</executions>
|
||||||
</plugin>
|
</plugin>
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<configuration>
|
|
||||||
<argLine></argLine>
|
|
||||||
|
|
||||||
<!--
|
|
||||||
By default: Surefire will set the classpath based on the manifest. Because tests are not included
|
|
||||||
in the JAR, any tests that rely on class path scanning for resources in the tests directory will not
|
|
||||||
function correctly without this configuration.
|
|
||||||
For example, tests for custom transforms (where the custom transform is defined in the test directory)
|
|
||||||
will fail due to the custom transform not being found on the classpath.
|
|
||||||
http://maven.apache.org/surefire/maven-surefire-plugin/examples/class-loading.html
|
|
||||||
-->
|
|
||||||
<useSystemClassLoader>true</useSystemClassLoader>
|
|
||||||
<useManifestOnlyJar>false</useManifestOnlyJar>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.eclipse.m2e</groupId>
|
<groupId>org.eclipse.m2e</groupId>
|
||||||
<artifactId>lifecycle-mapping</artifactId>
|
<artifactId>lifecycle-mapping</artifactId>
|
||||||
|
@ -249,7 +236,7 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -266,7 +253,7 @@
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -286,9 +273,6 @@
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<configuration>
|
|
||||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
</build>
|
</build>
|
||||||
|
|
|
@ -64,7 +64,7 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -75,7 +75,7 @@
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
|
|
@ -56,10 +56,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -166,7 +166,7 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -177,7 +177,7 @@
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
|
|
@ -23,7 +23,6 @@ import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.ExpectedException;
|
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
|
||||||
|
@ -34,7 +33,6 @@ import java.util.List;
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
|
||||||
|
|
||||||
@DisplayName("Early Termination Data Set Iterator Test")
|
@DisplayName("Early Termination Data Set Iterator Test")
|
||||||
class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
|
class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest {
|
||||||
|
|
|
@ -21,19 +21,16 @@ package org.deeplearning4j.datasets.iterator;
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.ExpectedException;
|
|
||||||
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import org.junit.jupiter.api.DisplayName;
|
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@DisplayName("Early Termination Multi Data Set Iterator Test")
|
@DisplayName("Early Termination Multi Data Set Iterator Test")
|
||||||
|
|
|
@ -34,7 +34,6 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.rules.ExpectedException;
|
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -46,7 +45,6 @@ import java.util.Random;
|
||||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
import org.junit.jupiter.api.DisplayName;
|
import org.junit.jupiter.api.DisplayName;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
|
||||||
|
|
||||||
@Disabled
|
@Disabled
|
||||||
@DisplayName("Attention Layer Test")
|
@DisplayName("Attention Layer Test")
|
||||||
|
|
|
@ -105,11 +105,12 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
<plugin>
|
<plugin>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
|
<inherited>true</inherited>
|
||||||
<configuration>
|
<configuration>
|
||||||
<skip>true</skip>
|
<skip>true</skip>
|
||||||
</configuration>
|
</configuration>
|
||||||
|
@ -118,7 +119,7 @@
|
||||||
</build>
|
</build>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -56,10 +56,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -50,10 +50,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -45,10 +45,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -54,10 +54,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -112,7 +112,7 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -123,7 +123,7 @@
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
|
|
@ -72,10 +72,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -306,7 +306,7 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -317,7 +317,7 @@
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
|
|
@ -101,10 +101,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -49,10 +49,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -127,10 +127,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -102,7 +102,7 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
@ -113,7 +113,7 @@
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
|
|
|
@ -99,10 +99,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -44,10 +44,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -89,10 +89,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -88,10 +88,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -90,10 +90,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -105,10 +105,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -182,10 +182,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -77,10 +77,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -104,10 +104,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -141,10 +141,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -79,10 +79,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -426,10 +426,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
|
@ -44,10 +44,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -87,10 +87,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -117,10 +117,10 @@
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
|
@ -143,6 +143,10 @@
|
||||||
</extensions>
|
</extensions>
|
||||||
|
|
||||||
<plugins>
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
|
</plugin>
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-enforcer-plugin</artifactId>
|
<artifactId>maven-enforcer-plugin</artifactId>
|
||||||
|
@ -158,7 +162,7 @@
|
||||||
<skip>${skipBackendChoice}</skip>
|
<skip>${skipBackendChoice}</skip>
|
||||||
<rules>
|
<rules>
|
||||||
<requireActiveProfile>
|
<requireActiveProfile>
|
||||||
<profiles>test-nd4j-native,test-nd4j-cuda-11.0</profiles>
|
<profiles>nd4j-tests-cpu,nd4j-tests-cuda</profiles>
|
||||||
<all>false</all>
|
<all>false</all>
|
||||||
</requireActiveProfile>
|
</requireActiveProfile>
|
||||||
</rules>
|
</rules>
|
||||||
|
@ -227,43 +231,6 @@
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
|
|
||||||
<pluginManagement>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<inherited>true</inherited>
|
|
||||||
<configuration>
|
|
||||||
<!--
|
|
||||||
By default: Surefire will set the classpath based on the manifest. Because tests are not included
|
|
||||||
in the JAR, any tests that rely on class path scanning for resources in the tests directory will not
|
|
||||||
function correctly without this configuration.
|
|
||||||
For example, tests for custom layers (where the custom layer is defined in the test directory)
|
|
||||||
will fail due to the custom layer not being found on the classpath.
|
|
||||||
http://maven.apache.org/surefire/maven-surefire-plugin/examples/class-loading.html
|
|
||||||
-->
|
|
||||||
<useSystemClassLoader>true</useSystemClassLoader>
|
|
||||||
<useManifestOnlyJar>false</useManifestOnlyJar>
|
|
||||||
<argLine> -Dfile.encoding=UTF-8 -Xmx8g "</argLine>
|
|
||||||
<includes>
|
|
||||||
<!-- Default setting only runs tests that start/end with "Test" -->
|
|
||||||
<include>*.java</include>
|
|
||||||
<include>**/*.java</include>
|
|
||||||
</includes>
|
|
||||||
</configuration>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.maven.surefire</groupId>
|
|
||||||
<artifactId>surefire-junit-platform</artifactId>
|
|
||||||
<version>${maven-surefire-plugin.version}</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.eclipse.m2e</groupId>
|
|
||||||
<artifactId>lifecycle-mapping</artifactId>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</pluginManagement>
|
|
||||||
</build>
|
</build>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
@ -290,10 +257,10 @@
|
||||||
<module>deeplearning4j-cuda</module>
|
<module>deeplearning4j-cuda</module>
|
||||||
</modules>
|
</modules>
|
||||||
</profile>
|
</profile>
|
||||||
<!-- For running unit tests with nd4j-native: "mvn clean test -P test-nd4j-native"
|
<!-- For running unit tests with nd4j-native: "mvn clean test -P nd4j-tests-cpu"
|
||||||
Note that this excludes DL4J-cuda -->
|
Note that this excludes DL4J-cuda -->
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-native</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
<activation>
|
<activation>
|
||||||
<activeByDefault>false</activeByDefault>
|
<activeByDefault>false</activeByDefault>
|
||||||
</activation>
|
</activation>
|
||||||
|
@ -311,70 +278,10 @@
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<inherited>true</inherited>
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-native</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.junit.jupiter</groupId>
|
|
||||||
<artifactId>junit-jupiter-engine</artifactId>
|
|
||||||
<version>${junit.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.junit.jupiter</groupId>
|
|
||||||
<artifactId>junit-jupiter-params</artifactId>
|
|
||||||
<version>${junit.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.maven.surefire</groupId>
|
|
||||||
<artifactId>surefire-junit-platform</artifactId>
|
|
||||||
<version>${maven-surefire-plugin.version}</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
<configuration>
|
|
||||||
<environmentVariables>
|
|
||||||
|
|
||||||
</environmentVariables>
|
|
||||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>*.java</include>
|
|
||||||
<include>**/*.java</include>
|
|
||||||
<include>**/Test*.java</include>
|
|
||||||
<include>**/*Test.java</include>
|
|
||||||
<include>**/*TestCase.java</include>
|
|
||||||
</includes>
|
|
||||||
<junitArtifactName>org.junit.jupiter:junit-jupiter-engine</junitArtifactName>
|
|
||||||
<systemPropertyVariables>
|
|
||||||
<org.nd4j.linalg.defaultbackend>
|
|
||||||
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
|
||||||
</org.nd4j.linalg.defaultbackend>
|
|
||||||
<org.nd4j.linalg.tests.backendstorun>
|
|
||||||
org.nd4j.linalg.cpu.nativecpu.CpuBackend
|
|
||||||
</org.nd4j.linalg.tests.backendstorun>
|
|
||||||
</systemPropertyVariables>
|
|
||||||
<!--
|
|
||||||
Maximum heap size was set to 8g, as a minimum required value for tests run.
|
|
||||||
Depending on a build machine, default value is not always enough.
|
|
||||||
|
|
||||||
For testing large zoo models, this may not be enough (so comment it out).
|
|
||||||
-->
|
|
||||||
<argLine></argLine>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
</profile>
|
</profile>
|
||||||
<!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" -->
|
<!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" -->
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
<id>nd4j-tests-cuda</id>
|
||||||
<activation>
|
<activation>
|
||||||
<activeByDefault>false</activeByDefault>
|
<activeByDefault>false</activeByDefault>
|
||||||
</activation>
|
</activation>
|
||||||
|
@ -392,43 +299,6 @@
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
<!-- Default to ALL modules here, unlike nd4j-native -->
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<version>${maven-surefire-plugin.version}</version>
|
|
||||||
<configuration>
|
|
||||||
<environmentVariables>
|
|
||||||
</environmentVariables>
|
|
||||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>*.java</include>
|
|
||||||
<include>**/*.java</include>
|
|
||||||
<include>**/Test*.java</include>
|
|
||||||
<include>**/*Test.java</include>
|
|
||||||
<include>**/*TestCase.java</include>
|
|
||||||
</includes>
|
|
||||||
<junitArtifactName>org.junit.jupiter:junit-jupiter</junitArtifactName>
|
|
||||||
<systemPropertyVariables>
|
|
||||||
<org.nd4j.linalg.defaultbackend>
|
|
||||||
org.nd4j.linalg.jcublas.JCublasBackend
|
|
||||||
</org.nd4j.linalg.defaultbackend>
|
|
||||||
<org.nd4j.linalg.tests.backendstorun>
|
|
||||||
org.nd4j.linalg.jcublas.JCublasBackend
|
|
||||||
</org.nd4j.linalg.tests.backendstorun>
|
|
||||||
</systemPropertyVariables>
|
|
||||||
<!--
|
|
||||||
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
|
||||||
Depending on a build machine, default value is not always enough.
|
|
||||||
-->
|
|
||||||
<argLine> -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
|
||||||
</configuration>
|
|
||||||
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -5,7 +5,7 @@ Linux
|
||||||
[INFO] Total time: 14.610 s
|
[INFO] Total time: 14.610 s
|
||||||
[INFO] Finished at: 2021-03-06T15:35:28+09:00
|
[INFO] Finished at: 2021-03-06T15:35:28+09:00
|
||||||
[INFO] ------------------------------------------------------------------------
|
[INFO] ------------------------------------------------------------------------
|
||||||
[WARNING] The requested profile "test-nd4j-native" could not be activated because it does not exist.
|
[WARNING] The requested profile "nd4j-tests-cpu" could not be activated because it does not exist.
|
||||||
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
|
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
|
||||||
[ERROR]
|
[ERROR]
|
||||||
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
|
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
|
||||||
|
@ -749,7 +749,7 @@ make[1]: Leaving directory '/c/Users/agibs/Documents/GitHub/eclipse-deeplearning
|
||||||
[INFO] Total time: 15.482 s
|
[INFO] Total time: 15.482 s
|
||||||
[INFO] Finished at: 2021-03-06T15:27:35+09:00
|
[INFO] Finished at: 2021-03-06T15:27:35+09:00
|
||||||
[INFO] ------------------------------------------------------------------------
|
[INFO] ------------------------------------------------------------------------
|
||||||
[WARNING] The requested profile "test-nd4j-native" could not be activated because it does not exist.
|
[WARNING] The requested profile "nd4j-tests-cpu" could not be activated because it does not exist.
|
||||||
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
|
[ERROR] Failed to execute goal org.bytedeco:javacpp:1.5.4:build (libnd4j-test-run) on project libnd4j: Execution libnd4j-test-run of goal org.bytedeco:javacpp:1.5.4:build failed: Process exited with an error: 127 -> [Help 1]
|
||||||
[ERROR]
|
[ERROR]
|
||||||
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
|
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
Const,in_0
|
||||||
|
Const,while/Const
|
||||||
|
Const,while/add/y
|
||||||
|
Identity,in_0/read
|
||||||
|
Enter,while/Enter
|
||||||
|
Enter,while/Enter_1
|
||||||
|
Merge,while/Merge
|
||||||
|
Merge,while/Merge_1
|
||||||
|
Less,while/Less
|
||||||
|
LoopCond,while/LoopCond
|
||||||
|
Switch,while/Switch
|
||||||
|
Switch,while/Switch_1
|
||||||
|
Identity,while/Identity
|
||||||
|
Exit,while/Exit
|
||||||
|
Identity,while/Identity_1
|
||||||
|
Exit,while/Exit_1
|
||||||
|
Add,while/add
|
||||||
|
NextIteration,while/NextIteration_1
|
||||||
|
NextIteration,while/NextIteration
|
|
@ -0,0 +1,16 @@
|
||||||
|
Identity,in_0/read
|
||||||
|
Enter,while/Enter
|
||||||
|
Enter,while/Enter_1
|
||||||
|
Merge,while/Merge
|
||||||
|
Merge,while/Merge_1
|
||||||
|
Less,while/Less
|
||||||
|
LoopCond,while/LoopCond
|
||||||
|
Switch,while/Switch
|
||||||
|
Switch,while/Switch_1
|
||||||
|
Identity,while/Identity
|
||||||
|
Exit,while/Exit
|
||||||
|
Identity,while/Identity_1
|
||||||
|
Exit,while/Exit_1
|
||||||
|
Add,while/add
|
||||||
|
NextIteration,while/NextIteration_1
|
||||||
|
NextIteration,while/NextIteration
|
|
@ -0,0 +1,19 @@
|
||||||
|
in_0
|
||||||
|
while/Const
|
||||||
|
while/add/y
|
||||||
|
in_0/read
|
||||||
|
while/Enter
|
||||||
|
while/Enter_1
|
||||||
|
while/Merge
|
||||||
|
while/Merge_1
|
||||||
|
while/Less
|
||||||
|
while/LoopCond
|
||||||
|
while/Switch
|
||||||
|
while/Switch_1
|
||||||
|
while/Identity
|
||||||
|
while/Exit
|
||||||
|
while/Identity_1
|
||||||
|
while/Exit_1
|
||||||
|
while/add
|
||||||
|
while/NextIteration_1
|
||||||
|
while/NextIteration
|
|
@ -303,7 +303,7 @@
|
||||||
|
|
||||||
For testing large zoo models, this may not be enough (so comment it out).
|
For testing large zoo models, this may not be enough (so comment it out).
|
||||||
-->
|
-->
|
||||||
<argLine>-Dfile.encoding=UTF-8 "</argLine>
|
<argLine>-Dfile.encoding=UTF-8 </argLine>
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
|
|
|
@ -27,6 +27,7 @@ import java.util.List;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.TestInfo;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -482,7 +483,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConv3d(Nd4jBackend backend) {
|
public void testConv3d(Nd4jBackend backend, TestInfo testInfo) {
|
||||||
//Pooling3d, Conv3D, batch norm
|
//Pooling3d, Conv3D, batch norm
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
@ -573,7 +574,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
tc.testName(msg);
|
tc.testName(msg);
|
||||||
String error = OpValidation.validate(tc);
|
String error = OpValidation.validate(tc);
|
||||||
if (error != null) {
|
if (error != null) {
|
||||||
failed.add(name);
|
failed.add(testInfo.getTestMethod().get().getName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1353,7 +1354,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err, err);
|
assertNull(err, err);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) {
|
public void exceptionThrown_WhenConv1DConfigInvalid(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
int nIn = 3;
|
int nIn = 3;
|
||||||
|
@ -1382,7 +1384,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) {
|
public void exceptionThrown_WhenConv2DConfigInvalid(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
@ -1405,7 +1408,8 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) {
|
public void exceptionThrown_WhenConf3DInvalid(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -22,6 +22,7 @@ package org.nd4j.autodiff.opvalidation;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -664,6 +665,7 @@ public class MiscOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
public void testMmulGradientManual(Nd4jBackend backend) {
|
public void testMmulGradientManual(Nd4jBackend backend) {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray sumInput = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
|
|
|
@ -69,7 +69,7 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
public void tearDown(Nd4jBackend backend) {
|
public void tearDown() {
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
|
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,7 @@ import lombok.val;
|
||||||
import org.apache.commons.math3.linear.LUDecomposition;
|
import org.apache.commons.math3.linear.LUDecomposition;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.TestInfo;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.OpValidationSuite;
|
import org.nd4j.OpValidationSuite;
|
||||||
|
@ -83,7 +84,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConcat(Nd4jBackend backend) {
|
public void testConcat(Nd4jBackend backend, TestInfo testInfo) {
|
||||||
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
|
// int[] concatDim = new int[]{0,0,0,1,1,1,2,2,2};
|
||||||
int[] concatDim = new int[]{0, 0, 0};
|
int[] concatDim = new int[]{0, 0, 0};
|
||||||
List<List<int[]>> origShapes = new ArrayList<>();
|
List<List<int[]>> origShapes = new ArrayList<>();
|
||||||
|
@ -115,7 +116,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
String error = OpValidation.validate(tc);
|
String error = OpValidation.validate(tc);
|
||||||
if(error != null){
|
if(error != null){
|
||||||
failed.add(name);
|
failed.add(testInfo.getTestMethod().get().getName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -285,7 +286,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSqueezeGradient(Nd4jBackend backend) {
|
public void testSqueezeGradient(Nd4jBackend backend,TestInfo testInfo) {
|
||||||
val origShape = new long[]{3, 4, 5};
|
val origShape = new long[]{3, 4, 5};
|
||||||
|
|
||||||
List<String> failed = new ArrayList<>();
|
List<String> failed = new ArrayList<>();
|
||||||
|
@ -339,7 +340,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
String error = OpValidation.validate(tc, true);
|
String error = OpValidation.validate(tc, true);
|
||||||
if(error != null){
|
if(error != null){
|
||||||
failed.add(name);
|
failed.add(testInfo.getTestMethod().get().getName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -580,8 +581,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
return Long.MAX_VALUE;
|
return Long.MAX_VALUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
public void testStack(Nd4jBackend backend) {
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
public void testStack(Nd4jBackend backend,TestInfo testInfo) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
List<String> failed = new ArrayList<>();
|
List<String> failed = new ArrayList<>();
|
||||||
|
@ -661,7 +663,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
String error = OpValidation.validate(tc);
|
String error = OpValidation.validate(tc);
|
||||||
if(error != null){
|
if(error != null){
|
||||||
failed.add(name);
|
failed.add(testInfo.getTestMethod().get().getName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,6 +72,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering(){
|
public char ordering(){
|
||||||
|
@ -82,7 +84,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testBasic(Nd4jBackend backend) throws Exception {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
|
INDArray arr = Nd4j.linspace(1,12,12).reshape(3,4);
|
||||||
SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() );
|
SDVariable in = sd.placeHolder("in", arr.dataType(), arr.shape() );
|
||||||
|
@ -138,7 +140,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testSimple(Nd4jBackend backend) throws Exception {
|
||||||
for( int i = 0; i < 10; i++ ) {
|
for( int i = 0; i < 10; i++ ) {
|
||||||
for(boolean execFirst : new boolean[]{false, true}) {
|
for(boolean execFirst : new boolean[]{false, true}) {
|
||||||
log.info("Starting test: i={}, execFirst={}", i, execFirst);
|
log.info("Starting test: i={}, execFirst={}", i, execFirst);
|
||||||
|
@ -268,7 +270,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTrainingSerde(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testTrainingSerde(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
//Ensure 2 things:
|
//Ensure 2 things:
|
||||||
//1. Training config is serialized/deserialized correctly
|
//1. Training config is serialized/deserialized correctly
|
||||||
|
|
|
@ -109,7 +109,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void before(Nd4jBackend backend) {
|
public void before() {
|
||||||
Nd4j.create(1);
|
Nd4j.create(1);
|
||||||
initialType = Nd4j.dataType();
|
initialType = Nd4j.dataType();
|
||||||
|
|
||||||
|
@ -118,7 +118,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
public void after(Nd4jBackend backend) {
|
public void after() {
|
||||||
Nd4j.setDataType(initialType);
|
Nd4j.setDataType(initialType);
|
||||||
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
||||||
|
|
|
@ -21,9 +21,6 @@
|
||||||
package org.nd4j.autodiff.samediff.listeners;
|
package org.nd4j.autodiff.samediff.listeners;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -47,11 +44,11 @@ import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -97,7 +94,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckpointEveryEpoch(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testCheckpointEveryEpoch(Nd4jBackend backend) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
|
|
||||||
SameDiff sd = getModel();
|
SameDiff sd = getModel();
|
||||||
|
@ -132,7 +129,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckpointEvery5Iter(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testCheckpointEvery5Iter(Nd4jBackend backend) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
|
|
||||||
SameDiff sd = getModel();
|
SameDiff sd = getModel();
|
||||||
|
@ -172,7 +169,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckpointListenerEveryTimeUnit(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testCheckpointListenerEveryTimeUnit(Nd4jBackend backend) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
SameDiff sd = getModel();
|
SameDiff sd = getModel();
|
||||||
|
|
||||||
|
@ -217,7 +214,7 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCheckpointListenerKeepLast3AndEvery3(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testCheckpointListenerKeepLast3AndEvery3(Nd4jBackend backend) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
SameDiff sd = getModel();
|
SameDiff sd = getModel();
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ package org.nd4j.autodiff.samediff.listeners;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
@ -49,6 +50,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
|
|
||||||
public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
|
public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
|
@ -59,7 +61,8 @@ public class ProfilingListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testProfilingListenerSimple(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
@Disabled
|
||||||
|
public void testProfilingListenerSimple(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
|
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
|
||||||
|
|
|
@ -64,6 +64,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering(){
|
public char ordering(){
|
||||||
|
@ -81,7 +82,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSimple(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
|
public void testSimple(Nd4jBackend backend) throws IOException {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
|
SDVariable v = sd.var("variable", DataType.DOUBLE, 3, 4);
|
||||||
SDVariable sum = v.sum();
|
SDVariable sum = v.sum();
|
||||||
|
@ -186,7 +187,7 @@ public class FileReadWriteTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNullBinLabels(@TempDir Path testDir,Nd4jBackend backend) throws Exception{
|
public void testNullBinLabels(Nd4jBackend backend) throws Exception{
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
File f = new File(dir, "temp.bin");
|
File f = new File(dir, "temp.bin");
|
||||||
LogFileWriter w = new LogFileWriter(f);
|
LogFileWriter w = new LogFileWriter(f);
|
||||||
|
|
|
@ -56,6 +56,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
public class UIListenerTest extends BaseNd4jTestWithBackends {
|
public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
return 'c';
|
return 'c';
|
||||||
|
@ -65,7 +67,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUIListenerBasic(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testUIListenerBasic(Nd4jBackend backend) throws Exception {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
|
@ -102,7 +104,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUIListenerContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testUIListenerContinue(Nd4jBackend backend) throws Exception {
|
||||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
|
|
||||||
SameDiff sd1 = getSimpleNet();
|
SameDiff sd1 = getSimpleNet();
|
||||||
|
@ -194,7 +196,7 @@ public class UIListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUIListenerBadContinue(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testUIListenerBadContinue(Nd4jBackend backend) throws Exception {
|
||||||
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
IrisDataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
SameDiff sd1 = getSimpleNet();
|
SameDiff sd1 = getSimpleNet();
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ package org.nd4j.evaluation;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -48,6 +49,7 @@ public class NewInstanceTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
public void testNewInstances(Nd4jBackend backend) {
|
public void testNewInstances(Nd4jBackend backend) {
|
||||||
boolean print = true;
|
boolean print = true;
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
|
|
||||||
package org.nd4j.evaluation;
|
package org.nd4j.evaluation;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.evaluation.classification.ROC;
|
import org.nd4j.evaluation.classification.ROC;
|
||||||
|
@ -50,6 +50,7 @@ public class ROCBinaryTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
public void testROCBinary(Nd4jBackend backend) {
|
public void testROCBinary(Nd4jBackend backend) {
|
||||||
//Compare ROCBinary to ROC class
|
//Compare ROCBinary to ROC class
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.nd4j.evaluation;
|
package org.nd4j.evaluation;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -971,7 +972,8 @@ public class ROCTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRocBinaryMerge(){
|
@Disabled
|
||||||
|
public void testRocBinaryMerge(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
ROCBinary roc = new ROCBinary();
|
ROCBinary roc = new ROCBinary();
|
||||||
|
|
|
@ -50,7 +50,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEvalParameters(Nd4jBackend backend) {
|
public void testEvalParameters(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
int specCols = 5;
|
int specCols = 5;
|
||||||
|
|
|
@ -152,7 +152,7 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
public void maskWhenMerge(Nd4jBackend backend) {
|
public void maskWhenMerge(Nd4jBackend backend) {
|
||||||
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
|
DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
|
||||||
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
|
DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
|
||||||
List<DataSet> dataSetList = new ArrayList<DataSet>();
|
List<DataSet> dataSetList = new ArrayList<>();
|
||||||
dataSetList.add(dsA);
|
dataSetList.add(dsA);
|
||||||
dataSetList.add(dsB);
|
dataSetList.add(dsB);
|
||||||
DataSet fullDataSet = DataSet.merge(dataSetList);
|
DataSet fullDataSet = DataSet.merge(dataSetList);
|
||||||
|
@ -175,7 +175,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
// System.out.println(b);
|
// System.out.println(b);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
//broken at a threshold
|
//broken at a threshold
|
||||||
public void testArgMax(Nd4jBackend backend) {
|
public void testArgMax(Nd4jBackend backend) {
|
||||||
int max = 63;
|
int max = 63;
|
||||||
|
@ -263,7 +264,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
// log.info("p50: {}; avg: {};", times.get(times.size() / 2), time);
|
// log.info("p50: {}; avg: {};", times.get(times.size() / 2), time);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void checkIllegalElementOps(Nd4jBackend backend) {
|
public void checkIllegalElementOps(Nd4jBackend backend) {
|
||||||
assertThrows(Exception.class,() -> {
|
assertThrows(Exception.class,() -> {
|
||||||
INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5);
|
INDArray A = Nd4j.linspace(1, 20, 20).reshape(4, 5);
|
||||||
|
|
|
@ -245,7 +245,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, false);
|
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, false);
|
||||||
assertEquals(sorted[1], sorted2);
|
assertEquals(sorted[1], sorted2);
|
||||||
INDArray shouldIndex = Nd4j.create(new double[] {1, 1, 0, 0}, new long[] {2, 2});
|
INDArray shouldIndex = Nd4j.create(new double[] {1, 1, 0, 0}, new long[] {2, 2});
|
||||||
assertEquals(shouldIndex, sorted[0],getFailureMessage());
|
assertEquals(shouldIndex, sorted[0],getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -266,7 +266,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, true);
|
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, true);
|
||||||
assertEquals(sorted[1], sorted2);
|
assertEquals(sorted[1], sorted2);
|
||||||
INDArray shouldIndex = Nd4j.create(new double[] {0, 0, 1, 1}, new long[] {2, 2});
|
INDArray shouldIndex = Nd4j.create(new double[] {0, 0, 1, 1}, new long[] {2, 2});
|
||||||
assertEquals(shouldIndex, sorted[0],getFailureMessage());
|
assertEquals(shouldIndex, sorted[0],getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -328,13 +328,13 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
public void testDivide(Nd4jBackend backend) {
|
public void testDivide(Nd4jBackend backend) {
|
||||||
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
|
INDArray two = Nd4j.create(new float[] {2, 2, 2, 2});
|
||||||
INDArray div = two.div(two);
|
INDArray div = two.div(two);
|
||||||
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage());
|
assertEquals( Nd4j.ones(DataType.FLOAT, 4), div,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray half = Nd4j.create(new float[] {0.5f, 0.5f, 0.5f, 0.5f}, new long[] {2, 2});
|
INDArray half = Nd4j.create(new float[] {0.5f, 0.5f, 0.5f, 0.5f}, new long[] {2, 2});
|
||||||
INDArray divi = Nd4j.create(new float[] {0.3f, 0.6f, 0.9f, 0.1f}, new long[] {2, 2});
|
INDArray divi = Nd4j.create(new float[] {0.3f, 0.6f, 0.9f, 0.1f}, new long[] {2, 2});
|
||||||
INDArray assertion = Nd4j.create(new float[] {1.6666666f, 0.8333333f, 0.5555556f, 5}, new long[] {2, 2});
|
INDArray assertion = Nd4j.create(new float[] {1.6666666f, 0.8333333f, 0.5555556f, 5}, new long[] {2, 2});
|
||||||
INDArray result = half.div(divi);
|
INDArray result = half.div(divi);
|
||||||
assertEquals( assertion, result,getFailureMessage());
|
assertEquals( assertion, result,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -344,7 +344,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
|
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
|
||||||
INDArray sigmoid = Transforms.sigmoid(n, false);
|
INDArray sigmoid = Transforms.sigmoid(n, false);
|
||||||
assertEquals( assertion, sigmoid,getFailureMessage());
|
assertEquals( assertion, sigmoid,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -354,7 +354,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
||||||
INDArray neg = Transforms.neg(n);
|
INDArray neg = Transforms.neg(n);
|
||||||
assertEquals(assertion, neg,getFailureMessage());
|
assertEquals(assertion, neg,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -365,12 +365,12 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
double sim = Transforms.cosineSim(vec1, vec2);
|
double sim = Transforms.cosineSim(vec1, vec2);
|
||||||
assertEquals(1, sim, 1e-1,getFailureMessage());
|
assertEquals(1, sim, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
|
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
|
||||||
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
|
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
|
||||||
sim = Transforms.cosineSim(vec3, vec4);
|
sim = Transforms.cosineSim(vec3, vec4);
|
||||||
assertEquals(0.98, sim, 1e-1,getFailureMessage());
|
assertEquals(0.98, sim, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -621,7 +621,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray innerProduct = n.mmul(transposed);
|
INDArray innerProduct = n.mmul(transposed);
|
||||||
|
|
||||||
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
|
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
|
||||||
assertEquals(scalar, innerProduct,getFailureMessage());
|
assertEquals(scalar, innerProduct,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -678,7 +678,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray five = Nd4j.ones(5);
|
INDArray five = Nd4j.ones(5);
|
||||||
five.addi(five.dup());
|
five.addi(five.dup());
|
||||||
INDArray twos = Nd4j.valueArrayOf(5, 2);
|
INDArray twos = Nd4j.valueArrayOf(5, 2);
|
||||||
assertEquals(twos, five,getFailureMessage());
|
assertEquals(twos, five,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -692,7 +692,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
||||||
|
|
||||||
INDArray test = arr.mmul(arr.transpose());
|
INDArray test = arr.mmul(arr.transpose());
|
||||||
assertEquals(assertion, test,getFailureMessage());
|
assertEquals(assertion, test,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -704,7 +704,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
Nd4j.exec(new PrintVariable(newSlice));
|
Nd4j.exec(new PrintVariable(newSlice));
|
||||||
log.info("Slice: {}", newSlice);
|
log.info("Slice: {}", newSlice);
|
||||||
n.putSlice(0, newSlice);
|
n.putSlice(0, newSlice);
|
||||||
assertEquals( newSlice, n.slice(0),getFailureMessage());
|
assertEquals( newSlice, n.slice(0),getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -713,7 +713,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
|
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
|
||||||
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
|
INDArray linear = Nd4j.create(DataType.DOUBLE, 1, 4);
|
||||||
linear.putScalar(new long[] {0, 1}, 1);
|
linear.putScalar(new long[] {0, 1}, 1);
|
||||||
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage());
|
assertEquals(linear.getDouble(0, 1), 1, 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1059,7 +1059,7 @@ public class NDArrayTestsFortran extends BaseNd4jTestWithBackends {
|
||||||
INDArray nClone = n1.add(n2);
|
INDArray nClone = n1.add(n2);
|
||||||
assertEquals(Nd4j.scalar(3), nClone);
|
assertEquals(Nd4j.scalar(3), nClone);
|
||||||
INDArray n1PlusN2 = n1.add(n2);
|
INDArray n1PlusN2 = n1.add(n2);
|
||||||
assertFalse(n1PlusN2.equals(n1),getFailureMessage());
|
assertFalse(n1PlusN2.equals(n1),getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray n3 = Nd4j.scalar(3.0);
|
INDArray n3 = Nd4j.scalar(3.0);
|
||||||
INDArray n4 = Nd4j.scalar(4.0);
|
INDArray n4 = Nd4j.scalar(4.0);
|
||||||
|
|
|
@ -156,7 +156,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
DataType initialType = Nd4j.dataType();
|
DataType initialType = Nd4j.dataType();
|
||||||
Level1 l1 = Nd4j.getBlasWrapper().level1();
|
Level1 l1 = Nd4j.getBlasWrapper().level1();
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
|
@ -255,7 +255,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSerialization(@TempDir Path testDir) throws Exception {
|
public void testSerialization(Nd4jBackend backend) throws Exception {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray arr = Nd4j.rand(1, 20);
|
INDArray arr = Nd4j.rand(1, 20);
|
||||||
|
|
||||||
|
@ -339,7 +339,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(assertion,shapeTest);
|
assertArrayEquals(assertion,shapeTest);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled //temporary till libnd4j implements general broadcasting
|
@Disabled //temporary till libnd4j implements general broadcasting
|
||||||
public void testAutoBroadcastAdd(Nd4jBackend backend) {
|
public void testAutoBroadcastAdd(Nd4jBackend backend) {
|
||||||
INDArray left = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,1,2,1);
|
INDArray left = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,1,2,1);
|
||||||
|
@ -370,9 +371,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
n.divi(Nd4j.scalar(1.0d));
|
n.divi(Nd4j.scalar(1.0d));
|
||||||
|
|
||||||
n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
|
n = Nd4j.create(Nd4j.ones(27).data(), new long[] {3, 3, 3});
|
||||||
assertEquals(27, n.sumNumber().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(27, n.sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
INDArray a = n.slice(2);
|
INDArray a = n.slice(2);
|
||||||
assertEquals( true, Arrays.equals(new long[] {3, 3}, a.shape()),getFailureMessage());
|
assertEquals( true, Arrays.equals(new long[] {3, 3}, a.shape()),getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -478,12 +479,13 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
||||||
|
|
||||||
INDArray test = arr.mmul(arr.transpose());
|
INDArray test = arr.mmul(arr.transpose());
|
||||||
assertEquals(assertion, test,getFailureMessage());
|
assertEquals(assertion, test,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled
|
@Disabled
|
||||||
public void testMmulOp() throws Exception {
|
public void testMmulOp(Nd4jBackend backend) throws Exception {
|
||||||
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
|
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
|
||||||
INDArray z = Nd4j.create(2, 2);
|
INDArray z = Nd4j.create(2, 2);
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
||||||
|
@ -494,7 +496,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
DynamicCustomOp op = new Mmul(arr, arr, z, mMulTranspose);
|
DynamicCustomOp op = new Mmul(arr, arr, z, mMulTranspose);
|
||||||
Nd4j.getExecutioner().execAndReturn(op);
|
Nd4j.getExecutioner().execAndReturn(op);
|
||||||
|
|
||||||
assertEquals(assertion, z,getFailureMessage());
|
assertEquals(assertion, z,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -505,7 +507,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray row1 = oneThroughFour.getRow(1).dup();
|
INDArray row1 = oneThroughFour.getRow(1).dup();
|
||||||
oneThroughFour.subiRowVector(row1);
|
oneThroughFour.subiRowVector(row1);
|
||||||
INDArray result = Nd4j.create(new double[] {-2, -2, 0, 0}, new long[] {2, 2});
|
INDArray result = Nd4j.create(new double[] {-2, -2, 0, 0}, new long[] {2, 2});
|
||||||
assertEquals(result, oneThroughFour,getFailureMessage());
|
assertEquals(result, oneThroughFour,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1093,7 +1095,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertTrue(expAllOnes.all());
|
assertTrue(expAllOnes.all());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled
|
@Disabled
|
||||||
public void testSumAlongDim1sEdgeCases(Nd4jBackend backend) {
|
public void testSumAlongDim1sEdgeCases(Nd4jBackend backend) {
|
||||||
val shapes = new long[][] {
|
val shapes = new long[][] {
|
||||||
|
@ -1227,7 +1230,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray row1 = oneThroughFour.getRow(1);
|
INDArray row1 = oneThroughFour.getRow(1);
|
||||||
row1.addi(1);
|
row1.addi(1);
|
||||||
INDArray result = Nd4j.create(new double[] {1, 2, 4, 5}, new long[] {2, 2});
|
INDArray result = Nd4j.create(new double[] {1, 2, 4, 5}, new long[] {2, 2});
|
||||||
assertEquals(result, oneThroughFour,getFailureMessage());
|
assertEquals(result, oneThroughFour,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1241,8 +1244,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray linear = test.reshape(-1);
|
INDArray linear = test.reshape(-1);
|
||||||
linear.putScalar(2, 6);
|
linear.putScalar(2, 6);
|
||||||
linear.putScalar(3, 7);
|
linear.putScalar(3, 7);
|
||||||
assertEquals(6, linear.getFloat(2), 1e-1,getFailureMessage());
|
assertEquals(6, linear.getFloat(2), 1e-1,getFailureMessage(backend));
|
||||||
assertEquals(7, linear.getFloat(3), 1e-1,getFailureMessage());
|
assertEquals(7, linear.getFloat(3), 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1609,7 +1612,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
||||||
INDArray neg = Transforms.neg(n);
|
INDArray neg = Transforms.neg(n);
|
||||||
assertEquals(assertion, neg,getFailureMessage());
|
assertEquals(assertion, neg,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1622,13 +1625,13 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
double assertion = 5.47722557505;
|
double assertion = 5.47722557505;
|
||||||
double norm3 = n.norm2Number().doubleValue();
|
double norm3 = n.norm2Number().doubleValue();
|
||||||
assertEquals(assertion, norm3, 1e-1,getFailureMessage());
|
assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
||||||
INDArray row1 = row.getRow(1);
|
INDArray row1 = row.getRow(1);
|
||||||
double norm2 = row1.norm2Number().doubleValue();
|
double norm2 = row1.norm2Number().doubleValue();
|
||||||
double assertion2 = 5.0f;
|
double assertion2 = 5.0f;
|
||||||
assertEquals(assertion2, norm2, 1e-1,getFailureMessage());
|
assertEquals(assertion2, norm2, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
Nd4j.setDataType(initialType);
|
Nd4j.setDataType(initialType);
|
||||||
}
|
}
|
||||||
|
@ -1640,14 +1643,14 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
float assertion = 5.47722557505f;
|
float assertion = 5.47722557505f;
|
||||||
float norm3 = n.norm2Number().floatValue();
|
float norm3 = n.norm2Number().floatValue();
|
||||||
assertEquals(assertion, norm3, 1e-1,getFailureMessage());
|
assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
INDArray row = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
|
INDArray row = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
|
||||||
INDArray row1 = row.getRow(1);
|
INDArray row1 = row.getRow(1);
|
||||||
float norm2 = row1.norm2Number().floatValue();
|
float norm2 = row1.norm2Number().floatValue();
|
||||||
float assertion2 = 5.0f;
|
float assertion2 = 5.0f;
|
||||||
assertEquals(assertion2, norm2, 1e-1,getFailureMessage());
|
assertEquals(assertion2, norm2, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1659,7 +1662,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
||||||
double sim = Transforms.cosineSim(vec1, vec2);
|
double sim = Transforms.cosineSim(vec1, vec2);
|
||||||
assertEquals(1, sim, 1e-1,getFailureMessage());
|
assertEquals(1, sim, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
|
INDArray vec3 = Nd4j.create(new float[] {0.2f, 0.3f, 0.4f, 0.5f});
|
||||||
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
|
INDArray vec4 = Nd4j.create(new float[] {0.6f, 0.7f, 0.8f, 0.9f});
|
||||||
|
@ -1675,14 +1678,14 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
double assertion = 2;
|
double assertion = 2;
|
||||||
INDArray answer = Nd4j.create(new double[] {2, 4, 6, 8});
|
INDArray answer = Nd4j.create(new double[] {2, 4, 6, 8});
|
||||||
INDArray scal = Nd4j.getBlasWrapper().scal(assertion, answer);
|
INDArray scal = Nd4j.getBlasWrapper().scal(assertion, answer);
|
||||||
assertEquals(answer, scal,getFailureMessage());
|
assertEquals(answer, scal,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
||||||
INDArray row1 = row.getRow(1);
|
INDArray row1 = row.getRow(1);
|
||||||
double assertion2 = 5.0;
|
double assertion2 = 5.0;
|
||||||
INDArray answer2 = Nd4j.create(new double[] {15, 20});
|
INDArray answer2 = Nd4j.create(new double[] {15, 20});
|
||||||
INDArray scal2 = Nd4j.getBlasWrapper().scal(assertion2, row1);
|
INDArray scal2 = Nd4j.getBlasWrapper().scal(assertion2, row1);
|
||||||
assertEquals(answer2, scal2,getFailureMessage());
|
assertEquals(answer2, scal2,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2076,17 +2079,17 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray innerProduct = n.mmul(transposed);
|
INDArray innerProduct = n.mmul(transposed);
|
||||||
|
|
||||||
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
|
INDArray scalar = Nd4j.scalar(385.0).reshape(1,1);
|
||||||
assertEquals(scalar, innerProduct,getFailureMessage());
|
assertEquals(scalar, innerProduct,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray outerProduct = transposed.mmul(n);
|
INDArray outerProduct = transposed.mmul(n);
|
||||||
assertEquals(true, Shape.shapeEquals(new long[] {10, 10}, outerProduct.shape()),getFailureMessage());
|
assertEquals(true, Shape.shapeEquals(new long[] {10, 10}, outerProduct.shape()),getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
INDArray three = Nd4j.create(new double[] {3, 4});
|
INDArray three = Nd4j.create(new double[] {3, 4});
|
||||||
INDArray test = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2});
|
INDArray test = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2});
|
||||||
INDArray sliceRow = test.slice(0).getRow(1);
|
INDArray sliceRow = test.slice(0).getRow(1);
|
||||||
assertEquals(three, sliceRow,getFailureMessage());
|
assertEquals(three, sliceRow,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray twoSix = Nd4j.create(new double[] {2, 6}, new long[] {2, 1});
|
INDArray twoSix = Nd4j.create(new double[] {2, 6}, new long[] {2, 1});
|
||||||
INDArray threeTwoSix = three.mmul(twoSix);
|
INDArray threeTwoSix = three.mmul(twoSix);
|
||||||
|
@ -2114,7 +2117,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
INDArray k1 = n1.transpose();
|
INDArray k1 = n1.transpose();
|
||||||
|
|
||||||
INDArray testVectorVector = k1.mmul(n1);
|
INDArray testVectorVector = k1.mmul(n1);
|
||||||
assertEquals(vectorVector, testVectorVector,getFailureMessage());
|
assertEquals(vectorVector, testVectorVector,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -2204,7 +2207,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(linear.getDouble(0, 1), 1, 1e-1);
|
assertEquals(linear.getDouble(0, 1), 1, 1e-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSize(Nd4jBackend backend) {
|
public void testSize(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
INDArray arr = Nd4j.create(4, 5);
|
INDArray arr = Nd4j.create(4, 5);
|
||||||
|
@ -2357,7 +2361,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled
|
@Disabled
|
||||||
public void testTensorDot(Nd4jBackend backend) {
|
public void testTensorDot(Nd4jBackend backend) {
|
||||||
INDArray oneThroughSixty = Nd4j.arange(60).reshape(3, 4, 5).castTo(DataType.DOUBLE);
|
INDArray oneThroughSixty = Nd4j.arange(60).reshape(3, 4, 5).castTo(DataType.DOUBLE);
|
||||||
|
@ -3051,10 +3056,10 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
public void testMeans(Nd4jBackend backend) {
|
public void testMeans(Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
INDArray mean1 = a.mean(1);
|
INDArray mean1 = a.mean(1);
|
||||||
assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage());
|
assertEquals(Nd4j.create(new double[] {1.5, 3.5}), mean1,getFailureMessage(backend));
|
||||||
assertEquals(Nd4j.create(new double[] {2, 3}), a.mean(0),getFailureMessage());
|
assertEquals(Nd4j.create(new double[] {2, 3}), a.mean(0),getFailureMessage(backend));
|
||||||
assertEquals(2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(2.5, Nd4j.linspace(1, 4, 4, DataType.DOUBLE).meanNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
assertEquals(2.5, a.meanNumber().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(2.5, a.meanNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3063,9 +3068,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSums(Nd4jBackend backend) {
|
public void testSums(Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage());
|
assertEquals(Nd4j.create(new double[] {3, 7}), a.sum(1),getFailureMessage(backend));
|
||||||
assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage());
|
assertEquals(Nd4j.create(new double[] {4, 6}), a.sum(0),getFailureMessage(backend));
|
||||||
assertEquals(10, a.sumNumber().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(10, a.sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -3438,7 +3443,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled
|
@Disabled
|
||||||
public void largeInstantiation(Nd4jBackend backend) {
|
public void largeInstantiation(Nd4jBackend backend) {
|
||||||
Nd4j.ones((1024 * 1024 * 511) + (1024 * 1024 - 1)); // Still works; this can even be called as often as I want, allowing me even to spill over on disk
|
Nd4j.ones((1024 * 1024 * 511) + (1024 * 1024 - 1)); // Still works; this can even be called as often as I want, allowing me even to spill over on disk
|
||||||
|
@ -3487,7 +3493,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(cSum, fSum); //Expect: 4,6. Getting [4, 4] for f order
|
assertEquals(cSum, fSum); //Expect: 4,6. Getting [4, 4] for f order
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled //not relevant anymore
|
@Disabled //not relevant anymore
|
||||||
public void testAssignMixedC(Nd4jBackend backend) {
|
public void testAssignMixedC(Nd4jBackend backend) {
|
||||||
int[] shape1 = {3, 2, 2, 2, 2, 2};
|
int[] shape1 = {3, 2, 2, 2, 2, 2};
|
||||||
|
@ -3787,7 +3794,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(assertion, result);
|
assertEquals(assertion, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPullRowsValidation1(Nd4jBackend backend) {
|
public void testPullRowsValidation1(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
Nd4j.pullRows(Nd4j.create(10, 10), 2, new int[] {0, 1, 2});
|
Nd4j.pullRows(Nd4j.create(10, 10), 2, new int[] {0, 1, 2});
|
||||||
|
@ -3795,7 +3803,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPullRowsValidation2(Nd4jBackend backend) {
|
public void testPullRowsValidation2(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, -1, 2});
|
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, -1, 2});
|
||||||
|
@ -3803,7 +3812,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPullRowsValidation3(Nd4jBackend backend) {
|
public void testPullRowsValidation3(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, 1, 10});
|
Nd4j.pullRows(Nd4j.create(10, 10), 1, new int[] {0, 1, 10});
|
||||||
|
@ -3811,7 +3821,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPullRowsValidation4(Nd4jBackend backend) {
|
public void testPullRowsValidation4(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2, 3});
|
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2, 3});
|
||||||
|
@ -3819,7 +3830,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPullRowsValidation5(Nd4jBackend backend) {
|
public void testPullRowsValidation5(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2}, 'e');
|
Nd4j.pullRows(Nd4j.create(3, 10), 1, new int[] {0, 1, 2}, 'e');
|
||||||
|
@ -4975,7 +4987,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTadReduce3_5(Nd4jBackend backend) {
|
public void testTadReduce3_5(Nd4jBackend backend) {
|
||||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||||
INDArray initial = Nd4j.create(5, 10);
|
INDArray initial = Nd4j.create(5, 10);
|
||||||
|
@ -6004,7 +6017,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled
|
@Disabled
|
||||||
public void testLogExpSum1(Nd4jBackend backend) {
|
public void testLogExpSum1(Nd4jBackend backend) {
|
||||||
INDArray matrix = Nd4j.create(3, 3);
|
INDArray matrix = Nd4j.create(3, 3);
|
||||||
|
@ -6019,7 +6033,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled
|
@Disabled
|
||||||
public void testLogExpSum2(Nd4jBackend backend) {
|
public void testLogExpSum2(Nd4jBackend backend) {
|
||||||
INDArray row = Nd4j.create(new double[]{1, 2, 3});
|
INDArray row = Nd4j.create(new double[]{1, 2, 3});
|
||||||
|
@ -6246,7 +6261,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReshapeFailure(Nd4jBackend backend) {
|
public void testReshapeFailure(Nd4jBackend backend) {
|
||||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||||
val a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2);
|
val a = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2,2);
|
||||||
|
@ -6345,7 +6361,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(new long[]{3, 2}, newShape.shape());
|
assertArrayEquals(new long[]{3, 2}, newShape.shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTranspose1(Nd4jBackend backend) {
|
public void testTranspose1(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
|
val vector = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5, 6});
|
||||||
|
@ -6360,7 +6377,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTranspose2(Nd4jBackend backend) {
|
public void testTranspose2(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
val scalar = Nd4j.scalar(2.f);
|
val scalar = Nd4j.scalar(2.f);
|
||||||
|
@ -6375,7 +6393,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
//@Disabled
|
//@Disabled
|
||||||
public void testMatmul_128by256(Nd4jBackend backend) {
|
public void testMatmul_128by256(Nd4jBackend backend) {
|
||||||
val mA = Nd4j.create(128, 156).assign(1.0f);
|
val mA = Nd4j.create(128, 156).assign(1.0f);
|
||||||
|
@ -6647,7 +6666,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp1, out1);
|
assertEquals(exp1, out1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBadReduce3Call(Nd4jBackend backend) {
|
public void testBadReduce3Call(Nd4jBackend backend) {
|
||||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||||
val x = Nd4j.create(400,20);
|
val x = Nd4j.create(400,20);
|
||||||
|
@ -7392,7 +7412,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(ez, z);
|
assertEquals(ez, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBroadcastInvalid() {
|
public void testBroadcastInvalid() {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
INDArray arr1 = Nd4j.ones(3,4,1);
|
INDArray arr1 = Nd4j.ones(3,4,1);
|
||||||
|
@ -7656,7 +7677,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp, array);
|
assertEquals(exp, array);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScatterUpdateShortcut_f1(Nd4jBackend backend) {
|
public void testScatterUpdateShortcut_f1(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalStateException.class,() -> {
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
val array = Nd4j.create(DataType.FLOAT, 5, 2);
|
val array = Nd4j.create(DataType.FLOAT, 5, 2);
|
||||||
|
@ -8041,7 +8063,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp, out); //Failing here
|
assertEquals(exp, out); //Failing here
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPullRowsFailure(Nd4jBackend backend) {
|
public void testPullRowsFailure(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
val idxs = new int[]{0,2,3,4};
|
val idxs = new int[]{0,2,3,4};
|
||||||
|
@ -8144,7 +8167,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp1, out1); //This is OK
|
assertEquals(exp1, out1); //This is OK
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPutRowValidation(Nd4jBackend backend) {
|
public void testPutRowValidation(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
val matrix = Nd4j.create(5, 10);
|
val matrix = Nd4j.create(5, 10);
|
||||||
|
@ -8155,7 +8179,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPutColumnValidation(Nd4jBackend backend) {
|
public void testPutColumnValidation(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
val matrix = Nd4j.create(5, 10);
|
val matrix = Nd4j.create(5, 10);
|
||||||
|
@ -8236,7 +8261,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScalarEq(){
|
public void testScalarEq(Nd4jBackend backend){
|
||||||
INDArray scalarRank2 = Nd4j.scalar(10.0).reshape(1,1);
|
INDArray scalarRank2 = Nd4j.scalar(10.0).reshape(1,1);
|
||||||
INDArray scalarRank1 = Nd4j.scalar(10.0).reshape(1);
|
INDArray scalarRank1 = Nd4j.scalar(10.0).reshape(1);
|
||||||
INDArray scalarRank0 = Nd4j.scalar(10.0);
|
INDArray scalarRank0 = Nd4j.scalar(10.0);
|
||||||
|
@ -8273,7 +8298,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testType1(@TempDir Path testDir) throws IOException {
|
@Disabled
|
||||||
|
public void testType1(Nd4jBackend backend) throws IOException {
|
||||||
for (int i = 0; i < 10; ++i) {
|
for (int i = 0; i < 10; ++i) {
|
||||||
INDArray in1 = Nd4j.rand(DataType.DOUBLE, new int[]{100, 100});
|
INDArray in1 = Nd4j.rand(DataType.DOUBLE, new int[]{100, 100});
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
|
@ -8295,7 +8321,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOnes(){
|
public void testOnes(Nd4jBackend backend){
|
||||||
INDArray arr = Nd4j.ones();
|
INDArray arr = Nd4j.ones();
|
||||||
INDArray arr2 = Nd4j.ones(DataType.LONG);
|
INDArray arr2 = Nd4j.ones(DataType.LONG);
|
||||||
assertEquals(0, arr.rank());
|
assertEquals(0, arr.rank());
|
||||||
|
@ -8306,7 +8332,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testZeros(){
|
public void testZeros(Nd4jBackend backend){
|
||||||
INDArray arr = Nd4j.zeros();
|
INDArray arr = Nd4j.zeros();
|
||||||
INDArray arr2 = Nd4j.zeros(DataType.LONG);
|
INDArray arr2 = Nd4j.zeros(DataType.LONG);
|
||||||
assertEquals(0, arr.rank());
|
assertEquals(0, arr.rank());
|
||||||
|
@ -8317,7 +8343,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testType2(@TempDir Path testDir) throws IOException {
|
@Disabled
|
||||||
|
public void testType2(Nd4jBackend backend) throws IOException {
|
||||||
for (int i = 0; i < 10; ++i) {
|
for (int i = 0; i < 10; ++i) {
|
||||||
INDArray in1 = Nd4j.ones(DataType.UINT16);
|
INDArray in1 = Nd4j.ones(DataType.UINT16);
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
|
|
|
@ -23,6 +23,7 @@ package org.nd4j.linalg;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -58,7 +59,8 @@ public class ToStringTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testToStringScalars(){
|
@Disabled
|
||||||
|
public void testToStringScalars(Nd4jBackend backend){
|
||||||
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
|
DataType[] dataTypes = new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.BOOL, DataType.INT, DataType.UINT32};
|
||||||
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};
|
String[] strs = new String[]{"1.0000", "1.0000", "true", "1", "1"};
|
||||||
|
|
||||||
|
|
|
@ -64,7 +64,6 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@Disabled
|
@Disabled
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@ -79,7 +78,6 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@Disabled
|
@Disabled
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@ -100,7 +98,8 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCreateNpy3(Nd4jBackend backend) throws Exception {
|
public void testCreateNpy3(Nd4jBackend backend) throws Exception {
|
||||||
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
|
INDArray arrCreate = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
|
||||||
assertEquals(8, arrCreate.length());
|
assertEquals(8, arrCreate.length());
|
||||||
|
@ -111,8 +110,9 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(arrCreate.data().address(), pointer.address());
|
assertEquals(arrCreate.data().address(), pointer.address());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@Disabled // this is endless test
|
@Disabled // this is endless test
|
||||||
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEndlessAllocation(Nd4jBackend backend) {
|
public void testEndlessAllocation(Nd4jBackend backend) {
|
||||||
Nd4j.getEnvironment().setMaxSpecialMemory(1);
|
Nd4j.getEnvironment().setMaxSpecialMemory(1);
|
||||||
while (true) {
|
while (true) {
|
||||||
|
@ -121,9 +121,10 @@ public class TestNDArrayCreation extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@Disabled("This test is designed to run in isolation. With parallel gc it makes no real sense since allocated amount changes at any time")
|
@Disabled("This test is designed to run in isolation. With parallel gc it makes no real sense since allocated amount changes at any time")
|
||||||
public void testAllocationLimits() throws Exception {
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
public void testAllocationLimits(Nd4jBackend backend) throws Exception {
|
||||||
Nd4j.create(1);
|
Nd4j.create(1);
|
||||||
|
|
||||||
val origDeviceLimit = Nd4j.getEnvironment().getDeviceLimit(0);
|
val origDeviceLimit = Nd4j.getEnvironment().getDeviceLimit(0);
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api;
|
package org.nd4j.linalg.api;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||||
|
|
|
@ -59,7 +59,7 @@ public class Level1Test extends BaseNd4jTestWithBackends {
|
||||||
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
INDArray row = matrix.getRow(1);
|
INDArray row = matrix.getRow(1);
|
||||||
Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row);
|
Nd4j.getBlasWrapper().level1().axpy(row.length(), 1.0, row, row);
|
||||||
assertEquals(Nd4j.create(new double[] {4, 8}), row,getFailureMessage());
|
assertEquals(Nd4j.create(new double[] {4, 8}), row,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -70,7 +70,6 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
|
||||||
/**
|
/**
|
||||||
* Testing level1 blas
|
* Testing level1 blas
|
||||||
*/
|
*/
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBlasValidation1(Nd4jBackend backend) {
|
public void testBlasValidation1(Nd4jBackend backend) {
|
||||||
|
@ -89,7 +88,6 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
|
||||||
/**
|
/**
|
||||||
* Testing level2 blas
|
* Testing level2 blas
|
||||||
*/
|
*/
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBlasValidation2(Nd4jBackend backend) {
|
public void testBlasValidation2(Nd4jBackend backend) {
|
||||||
|
@ -109,7 +107,6 @@ public class DataTypeValidationTests extends BaseNd4jTestWithBackends {
|
||||||
/**
|
/**
|
||||||
* Testing level3 blas
|
* Testing level3 blas
|
||||||
*/
|
*/
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBlasValidation3(Nd4jBackend backend) {
|
public void testBlasValidation3(Nd4jBackend backend) {
|
||||||
|
|
|
@ -88,7 +88,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
float[] d1 = new float[] {1, 2, 3, 4};
|
float[] d1 = new float[] {1, 2, 3, 4};
|
||||||
DataBuffer d = Nd4j.createBuffer(d1);
|
DataBuffer d = Nd4j.createBuffer(d1);
|
||||||
float[] d2 = d.asFloat();
|
float[] d2 = d.asFloat();
|
||||||
assertArrayEquals( d1, d2, 1e-1f,getFailureMessage());
|
assertArrayEquals( d1, d2, 1e-1f,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,7 +146,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
d.put(0, 0.0);
|
d.put(0, 0.0);
|
||||||
float[] result = new float[] {0, 2, 3, 4};
|
float[] result = new float[] {0, 2, 3, 4};
|
||||||
d1 = d.asFloat();
|
d1 = d.asFloat();
|
||||||
assertArrayEquals(d1, result, 1e-1f,getFailureMessage());
|
assertArrayEquals(d1, result, 1e-1f,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -156,12 +156,12 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
||||||
float[] get = buffer.getFloatsAt(0, 3);
|
float[] get = buffer.getFloatsAt(0, 3);
|
||||||
float[] data = new float[] {1, 2, 3};
|
float[] data = new float[] {1, 2, 3};
|
||||||
assertArrayEquals(get, data, 1e-1f,getFailureMessage());
|
assertArrayEquals(get, data, 1e-1f,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
float[] get2 = buffer.asFloat();
|
float[] get2 = buffer.asFloat();
|
||||||
float[] allData = buffer.getFloatsAt(0, (int) buffer.length());
|
float[] allData = buffer.getFloatsAt(0, (int) buffer.length());
|
||||||
assertArrayEquals(get2, allData, 1e-1f,getFailureMessage());
|
assertArrayEquals(get2, allData, 1e-1f,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -173,13 +173,13 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
DataBuffer buffer = Nd4j.linspace(1, 5, 5).data();
|
||||||
float[] get = buffer.getFloatsAt(1, 3);
|
float[] get = buffer.getFloatsAt(1, 3);
|
||||||
float[] data = new float[] {2, 3, 4};
|
float[] data = new float[] {2, 3, 4};
|
||||||
assertArrayEquals(get, data, 1e-1f,getFailureMessage());
|
assertArrayEquals(get, data, 1e-1f,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
float[] allButLast = new float[] {2, 3, 4, 5};
|
float[] allButLast = new float[] {2, 3, 4, 5};
|
||||||
|
|
||||||
float[] allData = buffer.getFloatsAt(1, (int) buffer.length());
|
float[] allData = buffer.getFloatsAt(1, (int) buffer.length());
|
||||||
assertArrayEquals(allButLast, allData, 1e-1f,getFailureMessage());
|
assertArrayEquals(allButLast, allData, 1e-1f,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -190,7 +190,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
public void testAsBytes(Nd4jBackend backend) {
|
public void testAsBytes(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.create(5);
|
INDArray arr = Nd4j.create(5);
|
||||||
byte[] d = arr.data().asBytes();
|
byte[] d = arr.data().asBytes();
|
||||||
assertEquals(4 * 5, d.length,getFailureMessage());
|
assertEquals(4 * 5, d.length,getFailureMessage(backend));
|
||||||
INDArray rand = Nd4j.rand(3, 3);
|
INDArray rand = Nd4j.rand(3, 3);
|
||||||
rand.data().asBytes();
|
rand.data().asBytes();
|
||||||
|
|
||||||
|
|
|
@ -20,26 +20,18 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.indexing;
|
package org.nd4j.linalg.api.indexing;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
import org.nd4j.common.util.ArrayUtil;
|
||||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
import org.nd4j.linalg.indexing.*;
|
||||||
import org.nd4j.linalg.indexing.IntervalIndex;
|
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndexAll;
|
|
||||||
import org.nd4j.linalg.indexing.NewAxis;
|
|
||||||
import org.nd4j.linalg.indexing.PointIndex;
|
|
||||||
import org.nd4j.linalg.indexing.SpecifiedIndex;
|
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import org.nd4j.common.util.ArrayUtil;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
@ -58,7 +50,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNegativeBounds() {
|
public void testNegativeBounds(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
|
INDArray arr = Nd4j.linspace(1,10,10, DataType.DOUBLE).reshape(2,5);
|
||||||
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
|
INDArrayIndex interval = NDArrayIndex.interval(0,1,-2,arr.size(1));
|
||||||
INDArray get = arr.get(NDArrayIndex.all(),interval);
|
INDArray get = arr.get(NDArrayIndex.all(),interval);
|
||||||
|
@ -71,7 +63,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNewAxis() {
|
public void testNewAxis(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2);
|
INDArray arr = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 2, 2);
|
||||||
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all());
|
INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.all(), newAxis(), newAxis(), all());
|
||||||
long[] shapeAssertion = {3, 2, 1, 1, 2};
|
long[] shapeAssertion = {3, 2, 1, 1, 2};
|
||||||
|
@ -81,7 +73,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void broadcastBug() {
|
public void broadcastBug(Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2});
|
INDArray a = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new int[] {2, 2});
|
||||||
final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0));
|
final INDArray col = a.get(NDArrayIndex.all(), NDArrayIndex.point(0));
|
||||||
|
|
||||||
|
@ -93,7 +85,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIntervalsIn3D() {
|
public void testIntervalsIn3D(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
|
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
|
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
|
||||||
INDArray rest = arr.get(interval(1, 2), interval(0, 2), interval(0, 2));
|
INDArray rest = arr.get(interval(1, 2), interval(0, 2), interval(0, 2));
|
||||||
|
@ -103,7 +95,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSmallInterval() {
|
public void testSmallInterval(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
|
INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE);
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
|
INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2);
|
||||||
INDArray rest = arr.get(interval(1, 2), all(), all());
|
INDArray rest = arr.get(interval(1, 2), all(), all());
|
||||||
|
@ -113,7 +105,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAllWithNewAxisAndInterval() {
|
public void testAllWithNewAxisAndInterval(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||||
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3);
|
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9},}).reshape(1, 1, 3);
|
||||||
|
|
||||||
|
@ -123,7 +115,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAllWithNewAxisInMiddle() {
|
public void testAllWithNewAxisInMiddle(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||||
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3);
|
INDArray assertion2 = Nd4j.create(new double[][] {{7, 8, 9}, {10, 11, 12}}).reshape(1, 2, 3);
|
||||||
|
|
||||||
|
@ -133,7 +125,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAllWithNewAxis() {
|
public void testAllWithNewAxis(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
INDArray arr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||||
INDArray get = arr.get(newAxis(), all(), point(1));
|
INDArray get = arr.get(newAxis(), all(), point(1));
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{4, 5, 6}, {10, 11, 12}, {16, 17, 18}, {22, 23, 24}})
|
INDArray assertion = Nd4j.create(new double[][] {{4, 5, 6}, {10, 11, 12}, {16, 17, 18}, {22, 23, 24}})
|
||||||
|
@ -144,7 +136,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIndexingWithMmul() {
|
public void testIndexingWithMmul(Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
|
INDArray a = Nd4j.linspace(1, 9, 9, DataType.DOUBLE).reshape(3, 3);
|
||||||
INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
|
INDArray b = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
|
||||||
// System.out.println(b);
|
// System.out.println(b);
|
||||||
|
@ -156,7 +148,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPointPointInterval() {
|
public void testPointPointInterval(Nd4jBackend backend) {
|
||||||
INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3);
|
INDArray wholeArr = Nd4j.linspace(1, 36, 36, DataType.DOUBLE).reshape(4, 3, 3);
|
||||||
INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3));
|
INDArray get = wholeArr.get(point(0), interval(1, 3), interval(1, 3));
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{5, 6}, {8, 9}});
|
INDArray assertion = Nd4j.create(new double[][] {{5, 6}, {8, 9}});
|
||||||
|
@ -166,7 +158,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIntervalLowerBound() {
|
public void testIntervalLowerBound(Nd4jBackend backend) {
|
||||||
INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
INDArray wholeArr = Nd4j.linspace(1, 24, 24, DataType.DOUBLE).reshape(4, 2, 3);
|
||||||
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
|
INDArray subarray = wholeArr.get(interval(1, 3), NDArrayIndex.point(0), NDArrayIndex.indices(0, 2));
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{7, 9}, {13, 15}});
|
INDArray assertion = Nd4j.create(new double[][] {{7, 9}, {13, 15}});
|
||||||
|
@ -178,7 +170,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetPointRowVector() {
|
public void testGetPointRowVector(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1);
|
INDArray arr = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE).reshape(1, -1);
|
||||||
|
|
||||||
INDArray arr2 = arr.get(point(0), interval(0, 100));
|
INDArray arr2 = arr.get(point(0), interval(0, 100));
|
||||||
|
@ -189,7 +181,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSpecifiedIndexVector() {
|
public void testSpecifiedIndexVector(Nd4jBackend backend) {
|
||||||
INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
|
INDArray rootMatrix = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4);
|
||||||
INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2);
|
INDArray threeD = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(2, 2, 2, 2);
|
||||||
INDArray get = rootMatrix.get(all(), new SpecifiedIndex(0, 2));
|
INDArray get = rootMatrix.get(all(), new SpecifiedIndex(0, 2));
|
||||||
|
@ -207,7 +199,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPutRowIndexing() {
|
public void testPutRowIndexing(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.ones(1, 10);
|
INDArray arr = Nd4j.ones(1, 10);
|
||||||
INDArray row = Nd4j.create(1, 10);
|
INDArray row = Nd4j.create(1, 10);
|
||||||
|
|
||||||
|
@ -218,7 +210,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testVectorIndexing2() {
|
public void testVectorIndexing2(Nd4jBackend backend) {
|
||||||
INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true));
|
INDArray wholeVector = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).get(interval(1, 2, 3, true));
|
||||||
INDArray assertion = Nd4j.create(new double[] {2, 4});
|
INDArray assertion = Nd4j.create(new double[] {2, 4});
|
||||||
assertEquals(assertion, wholeVector);
|
assertEquals(assertion, wholeVector);
|
||||||
|
@ -234,7 +226,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testOffsetsC() {
|
public void testOffsetsC(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
assertEquals(3, NDArrayIndex.offset(arr, 1, 1));
|
assertEquals(3, NDArrayIndex.offset(arr, 1, 1));
|
||||||
assertEquals(3, NDArrayIndex.offset(arr, point(1), point(1)));
|
assertEquals(3, NDArrayIndex.offset(arr, point(1), point(1)));
|
||||||
|
@ -251,7 +243,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIndexFor() {
|
public void testIndexFor(Nd4jBackend backend) {
|
||||||
long[] shape = {1, 2};
|
long[] shape = {1, 2};
|
||||||
INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape);
|
INDArrayIndex[] indexes = NDArrayIndex.indexesFor(shape);
|
||||||
for (int i = 0; i < indexes.length; i++) {
|
for (int i = 0; i < indexes.length; i++) {
|
||||||
|
@ -261,7 +253,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetScalar() {
|
public void testGetScalar(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
|
INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE);
|
||||||
INDArray d = arr.get(point(1));
|
INDArray d = arr.get(point(1));
|
||||||
assertTrue(d.isScalar());
|
assertTrue(d.isScalar());
|
||||||
|
@ -271,7 +263,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testVectorIndexing() {
|
public void testVectorIndexing(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1);
|
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).reshape(1, -1);
|
||||||
INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
|
INDArray assertion = Nd4j.create(new double[] {2, 3, 4, 5});
|
||||||
INDArray viewTest = arr.get(point(0), interval(1, 5));
|
INDArray viewTest = arr.get(point(0), interval(1, 5));
|
||||||
|
@ -280,7 +272,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNegativeIndices() {
|
public void testNegativeIndices(Nd4jBackend backend) {
|
||||||
INDArray test = Nd4j.create(10, 10, 10);
|
INDArray test = Nd4j.create(10, 10, 10);
|
||||||
test.putScalar(new int[] {0, 0, -1}, 1.0);
|
test.putScalar(new int[] {0, 0, -1}, 1.0);
|
||||||
assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber());
|
assertEquals(1.0, test.getScalar(0, 0, -1).sumNumber());
|
||||||
|
@ -288,7 +280,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetIndices2d() {
|
public void testGetIndices2d(Nd4jBackend backend) {
|
||||||
INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2);
|
INDArray twoByTwo = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(3, 2);
|
||||||
INDArray firstRow = twoByTwo.getRow(0);
|
INDArray firstRow = twoByTwo.getRow(0);
|
||||||
INDArray secondRow = twoByTwo.getRow(1);
|
INDArray secondRow = twoByTwo.getRow(1);
|
||||||
|
@ -307,7 +299,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetRow() {
|
public void testGetRow(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5);
|
INDArray in = Nd4j.linspace(0, 14, 15, DataType.DOUBLE).reshape(3, 5);
|
||||||
int[] toGet = {0, 1};
|
int[] toGet = {0, 1};
|
||||||
|
@ -325,7 +317,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetRowEdgeCase() {
|
public void testGetRowEdgeCase(Nd4jBackend backend) {
|
||||||
INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
|
INDArray rowVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1);
|
||||||
INDArray get = rowVec.getRow(0); //Returning shape [1,1]
|
INDArray get = rowVec.getRow(0); //Returning shape [1,1]
|
||||||
|
|
||||||
|
@ -335,7 +327,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetColumnEdgeCase() {
|
public void testGetColumnEdgeCase(Nd4jBackend backend) {
|
||||||
INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose();
|
INDArray colVec = Nd4j.linspace(1, 5, 5, DataType.DOUBLE).reshape(1, -1).transpose();
|
||||||
INDArray get = colVec.getColumn(0); //Returning shape [1,1]
|
INDArray get = colVec.getColumn(0); //Returning shape [1,1]
|
||||||
|
|
||||||
|
@ -345,7 +337,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConcatColumns() {
|
public void testConcatColumns(Nd4jBackend backend) {
|
||||||
INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE);
|
INDArray input1 = Nd4j.zeros(2, 1).castTo(DataType.DOUBLE);
|
||||||
INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE);
|
INDArray input2 = Nd4j.ones(2, 1).castTo(DataType.DOUBLE);
|
||||||
INDArray concat = Nd4j.concat(1, input1, input2);
|
INDArray concat = Nd4j.concat(1, input1, input2);
|
||||||
|
@ -355,7 +347,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetIndicesVector() {
|
public void testGetIndicesVector(Nd4jBackend backend) {
|
||||||
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
|
INDArray line = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(1, -1);
|
||||||
INDArray test = Nd4j.create(new double[] {2, 3});
|
INDArray test = Nd4j.create(new double[] {2, 3});
|
||||||
INDArray result = line.get(point(0), interval(1, 3));
|
INDArray result = line.get(point(0), interval(1, 3));
|
||||||
|
@ -364,7 +356,7 @@ public class IndexingTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testArangeMul() {
|
public void testArangeMul(Nd4jBackend backend) {
|
||||||
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
|
INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE);
|
||||||
INDArrayIndex index = interval(0, 2);
|
INDArrayIndex index = interval(0, 2);
|
||||||
INDArray get = arange.get(index, index);
|
INDArray get = arange.get(index, index);
|
||||||
|
|
|
@ -46,12 +46,13 @@ import java.util.List;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
|
||||||
public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
|
public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void compareAfterWrite(Nd4jBackend backend) throws Exception {
|
||||||
int [] ranksToCheck = new int[] {0,1,2,3,4};
|
int [] ranksToCheck = new int[] {0,1,2,3,4};
|
||||||
for (int i = 0; i < ranksToCheck.length; i++) {
|
for (int i = 0; i < ranksToCheck.length; i++) {
|
||||||
// log.info("Checking read write arrays with rank " + ranksToCheck[i]);
|
// log.info("Checking read write arrays with rank " + ranksToCheck[i]);
|
||||||
|
@ -82,7 +83,7 @@ public class TestNdArrReadWriteTxt extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNd4jReadWriteText(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testNd4jReadWriteText(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
int count = 0;
|
int count = 0;
|
||||||
|
|
|
@ -38,11 +38,11 @@ import static org.nd4j.linalg.api.ndarray.TestNdArrReadWriteTxt.compareArrays;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
|
||||||
public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends {
|
public class TestNdArrReadWriteTxtC extends BaseNd4jTestWithBackends {
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void compareAfterWrite(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void compareAfterWrite(Nd4jBackend backend) throws Exception {
|
||||||
int[] ranksToCheck = new int[]{0, 1, 2, 3, 4};
|
int[] ranksToCheck = new int[]{0, 1, 2, 3, 4};
|
||||||
for (int i = 0; i < ranksToCheck.length; i++) {
|
for (int i = 0; i < ranksToCheck.length; i++) {
|
||||||
log.info("Checking read write arrays with rank " + ranksToCheck[i]);
|
log.info("Checking read write arrays with rank " + ranksToCheck[i]);
|
||||||
|
|
|
@ -22,6 +22,7 @@ package org.nd4j.linalg.broadcast;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -135,7 +136,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(e, z);
|
assertEquals(e, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void basicBroadcastFailureTest_1(Nd4jBackend backend) {
|
public void basicBroadcastFailureTest_1(Nd4jBackend backend) {
|
||||||
|
@ -146,7 +146,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void basicBroadcastFailureTest_2(Nd4jBackend backend) {
|
public void basicBroadcastFailureTest_2(Nd4jBackend backend) {
|
||||||
|
@ -158,7 +157,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void basicBroadcastFailureTest_3(Nd4jBackend backend) {
|
public void basicBroadcastFailureTest_3(Nd4jBackend backend) {
|
||||||
|
@ -170,16 +168,15 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
public void basicBroadcastFailureTest_4(Nd4jBackend backend) {
|
public void basicBroadcastFailureTest_4(Nd4jBackend backend) {
|
||||||
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
|
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
|
||||||
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
|
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
|
||||||
val z = x.addi(y);
|
val z = x.addi(y);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void basicBroadcastFailureTest_5(Nd4jBackend backend) {
|
public void basicBroadcastFailureTest_5(Nd4jBackend backend) {
|
||||||
|
@ -191,7 +188,6 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void basicBroadcastFailureTest_6(Nd4jBackend backend) {
|
public void basicBroadcastFailureTest_6(Nd4jBackend backend) {
|
||||||
|
@ -249,9 +245,9 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(y, z);
|
assertEquals(y, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
public void emptyBroadcastTest_2(Nd4jBackend backend) {
|
public void emptyBroadcastTest_2(Nd4jBackend backend) {
|
||||||
val x = Nd4j.create(DataType.FLOAT, 1, 2);
|
val x = Nd4j.create(DataType.FLOAT, 1, 2);
|
||||||
val y = Nd4j.create(DataType.FLOAT, 0, 2);
|
val y = Nd4j.create(DataType.FLOAT, 0, 2);
|
||||||
|
|
|
@ -37,7 +37,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
public class CompressionMagicTests extends BaseNd4jTestWithBackends {
|
public class CompressionMagicTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void setUp(Nd4jBackend backend) {
|
public void setUp() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -48,6 +48,7 @@ import java.util.Set;
|
||||||
|
|
||||||
public class DeconvTests extends BaseNd4jTestWithBackends {
|
public class DeconvTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
|
@ -56,7 +57,7 @@ public class DeconvTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void compareKeras(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void compareKeras(Nd4jBackend backend) throws Exception {
|
||||||
File newFolder = testDir.toFile();
|
File newFolder = testDir.toFile();
|
||||||
new ClassPathResource("keras/deconv/").copyDirectory(newFolder);
|
new ClassPathResource("keras/deconv/").copyDirectory(newFolder);
|
||||||
|
|
||||||
|
|
|
@ -99,7 +99,8 @@ public class SpecialTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScalarShuffle1(Nd4jBackend backend) {
|
public void testScalarShuffle1(Nd4jBackend backend) {
|
||||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||||
List<DataSet> listData = new ArrayList<>();
|
List<DataSet> listData = new ArrayList<>();
|
||||||
|
|
|
@ -195,7 +195,8 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp, arrayX);
|
assertEquals(exp, arrayX);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInplaceOp1(Nd4jBackend backend) {
|
public void testInplaceOp1(Nd4jBackend backend) {
|
||||||
assertThrows(ND4JIllegalStateException.class,() -> {
|
assertThrows(ND4JIllegalStateException.class,() -> {
|
||||||
val arrayX = Nd4j.create(10, 10);
|
val arrayX = Nd4j.create(10, 10);
|
||||||
|
|
|
@ -41,10 +41,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
|
public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBalance(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testBalance(Nd4jBackend backend) throws Exception {
|
||||||
DataSetIterator iterator = new IrisDataSetIterator(10, 150);
|
DataSetIterator iterator = new IrisDataSetIterator(10, 150);
|
||||||
|
|
||||||
File minibatches = new File(testDir.toFile(),"mini-batch-dir");
|
File minibatches = new File(testDir.toFile(),"mini-batch-dir");
|
||||||
|
@ -62,7 +63,7 @@ public class BalanceMinibatchesTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMiniBatchBalanced(@TempDir Path testDir,Nd4jBackend backend) throws Exception {
|
public void testMiniBatchBalanced(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
int miniBatchSize = 100;
|
int miniBatchSize = 100;
|
||||||
DataSetIterator iterator = new IrisDataSetIterator(miniBatchSize, 150);
|
DataSetIterator iterator = new IrisDataSetIterator(miniBatchSize, 150);
|
||||||
|
|
|
@ -52,6 +52,8 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.*;
|
||||||
|
|
||||||
public class DataSetTest extends BaseNd4jTestWithBackends {
|
public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testViewIterator(Nd4jBackend backend) {
|
public void testViewIterator(Nd4jBackend backend) {
|
||||||
|
@ -116,7 +118,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(train.getTrain().getLabels().length(), 6);
|
assertEquals(train.getTrain().getLabels().length(), 6);
|
||||||
|
|
||||||
SplitTestAndTrain train2 = data.splitTestAndTrain(6, new Random(1));
|
SplitTestAndTrain train2 = data.splitTestAndTrain(6, new Random(1));
|
||||||
assertEquals(train.getTrain().getFeatures(), train2.getTrain().getFeatures(),getFailureMessage());
|
assertEquals(train.getTrain().getFeatures(), train2.getTrain().getFeatures(),getFailureMessage(backend));
|
||||||
|
|
||||||
DataSet x0 = new IrisDataSetIterator(150, 150).next();
|
DataSet x0 = new IrisDataSetIterator(150, 150).next();
|
||||||
SplitTestAndTrain testAndTrain = x0.splitTestAndTrain(10);
|
SplitTestAndTrain testAndTrain = x0.splitTestAndTrain(10);
|
||||||
|
@ -154,13 +156,13 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLabelCounts(Nd4jBackend backend) {
|
public void testLabelCounts(Nd4jBackend backend) {
|
||||||
DataSet x0 = new IrisDataSetIterator(150, 150).next();
|
DataSet x0 = new IrisDataSetIterator(150, 150).next();
|
||||||
assertEquals(0, x0.get(0).outcome(),getFailureMessage());
|
assertEquals(0, x0.get(0).outcome(),getFailureMessage(backend));
|
||||||
assertEquals( 0, x0.get(1).outcome(),getFailureMessage());
|
assertEquals( 0, x0.get(1).outcome(),getFailureMessage(backend));
|
||||||
assertEquals(2, x0.get(149).outcome(),getFailureMessage());
|
assertEquals(2, x0.get(149).outcome(),getFailureMessage(backend));
|
||||||
Map<Integer, Double> counts = x0.labelCounts();
|
Map<Integer, Double> counts = x0.labelCounts();
|
||||||
assertEquals(50, counts.get(0), 1e-1,getFailureMessage());
|
assertEquals(50, counts.get(0), 1e-1,getFailureMessage(backend));
|
||||||
assertEquals(50, counts.get(1), 1e-1,getFailureMessage());
|
assertEquals(50, counts.get(1), 1e-1,getFailureMessage(backend));
|
||||||
assertEquals(50, counts.get(2), 1e-1,getFailureMessage());
|
assertEquals(50, counts.get(2), 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1098,7 +1100,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend backend) throws IOException {
|
public void testDataSetMetaDataSerialization(Nd4jBackend backend) throws IOException {
|
||||||
|
|
||||||
for(boolean withMeta : new boolean[]{false, true}) {
|
for(boolean withMeta : new boolean[]{false, true}) {
|
||||||
// create simple data set with meta data object
|
// create simple data set with meta data object
|
||||||
|
@ -1129,7 +1131,7 @@ public class DataSetTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMultiDataSetMetaDataSerialization(@TempDir Path testDir,Nd4jBackend nd4jBackend) throws IOException {
|
public void testMultiDataSetMetaDataSerialization(Nd4jBackend nd4jBackend) throws IOException {
|
||||||
|
|
||||||
for(boolean withMeta : new boolean[]{false, true}) {
|
for(boolean withMeta : new boolean[]{false, true}) {
|
||||||
// create simple data set with meta data object
|
// create simple data set with meta data object
|
||||||
|
|
|
@ -106,7 +106,8 @@ public class KFoldIteratorTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void checkCornerCaseException(Nd4jBackend backend) {
|
public void checkCornerCaseException(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
DataSet allData = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1),
|
DataSet allData = new DataSet(Nd4j.linspace(1,99,99, DataType.DOUBLE).reshape(-1, 1),
|
||||||
|
|
|
@ -21,27 +21,25 @@
|
||||||
package org.nd4j.linalg.dataset;
|
package org.nd4j.linalg.dataset;
|
||||||
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTestWithBackends {
|
public class MiniBatchFileDataSetIteratorTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMiniBatches(@TempDir Path testDir) throws Exception {
|
public void testMiniBatches(Nd4jBackend backend) throws Exception {
|
||||||
DataSet load = new IrisDataSetIterator(150, 150).next();
|
DataSet load = new IrisDataSetIterator(150, 150).next();
|
||||||
final MiniBatchFileDataSetIterator iter = new MiniBatchFileDataSetIterator(load, 10, false, testDir.toFile());
|
final MiniBatchFileDataSetIterator iter = new MiniBatchFileDataSetIterator(load, 10, false, testDir.toFile());
|
||||||
while (iter.hasNext())
|
while (iter.hasNext())
|
||||||
|
|
|
@ -39,7 +39,6 @@ public class CompositeDataSetPreProcessorTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_preConditionsIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
public void when_preConditionsIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
||||||
|
|
|
@ -41,7 +41,6 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_originalHeightIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_originalHeightIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
|
@ -51,7 +50,6 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_originalWidthIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_originalWidthIsZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
|
@ -61,7 +59,6 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_yStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_yStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
|
@ -71,7 +68,6 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_xStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_xStartIsNegative_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
|
@ -81,7 +77,6 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
|
@ -91,7 +86,6 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
|
@ -101,7 +95,6 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException(Nd4jBackend backend) {
|
||||||
|
@ -111,7 +104,6 @@ public class CropAndResizeDataSetPreProcessorTest extends BaseNd4jTestWithBacken
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
||||||
|
|
|
@ -39,7 +39,8 @@ public class PermuteDataSetPreProcessorTest extends BaseNd4jTestWithBackends {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
||||||
assertThrows(NullPointerException.class,() -> {
|
assertThrows(NullPointerException.class,() -> {
|
||||||
// Assemble
|
// Assemble
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.dataset.api.preprocessor;
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||||
|
@ -39,7 +38,8 @@ public class RGBtoGrayscaleDataSetPreProcessorTest extends BaseNd4jTestWithBacke
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
public void when_dataSetIsNull_expect_NullPointerException(Nd4jBackend backend) {
|
||||||
assertThrows(NullPointerException.class,() -> {
|
assertThrows(NullPointerException.class,() -> {
|
||||||
// Assemble
|
// Assemble
|
||||||
|
|
|
@ -139,7 +139,7 @@ public class Nd4jTest extends BaseNd4jTestWithBackends {
|
||||||
INDArray actualResult = data.mean(0);
|
INDArray actualResult = data.mean(0);
|
||||||
INDArray expectedResult = Nd4j.create(new double[] {3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6.,
|
INDArray expectedResult = Nd4j.create(new double[] {3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6.,
|
||||||
6., 3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6.}, new int[] {2, 4, 4});
|
6., 3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6.}, new int[] {2, 4, 4});
|
||||||
assertEquals(expectedResult, actualResult,getFailureMessage());
|
assertEquals(expectedResult, actualResult,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -154,7 +154,7 @@ public class Nd4jTest extends BaseNd4jTestWithBackends {
|
||||||
INDArray actualResult = data.var(false, 0);
|
INDArray actualResult = data.var(false, 0);
|
||||||
INDArray expectedResult = Nd4j.create(new double[] {1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4.,
|
INDArray expectedResult = Nd4j.create(new double[] {1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4.,
|
||||||
4., 1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4.}, new long[] {2, 4, 4});
|
4., 1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4.}, new long[] {2, 4, 4});
|
||||||
assertEquals(expectedResult, actualResult,getFailureMessage());
|
assertEquals(expectedResult, actualResult,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
|
|
@ -83,7 +83,6 @@ public class CloseableTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAccessException_1(Nd4jBackend backend) {
|
public void testAccessException_1(Nd4jBackend backend) {
|
||||||
|
@ -96,7 +95,6 @@ public class CloseableTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAccessException_2(Nd4jBackend backend) {
|
public void testAccessException_2(Nd4jBackend backend) {
|
||||||
|
|
|
@ -384,7 +384,9 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(exp, arrayZ);
|
assertEquals(exp, arrayZ);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
|
||||||
public void testTypesValidation_1(Nd4jBackend backend) {
|
public void testTypesValidation_1(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.LONG);
|
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.LONG);
|
||||||
|
@ -397,7 +399,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTypesValidation_2(Nd4jBackend backend) {
|
public void testTypesValidation_2(Nd4jBackend backend) {
|
||||||
assertThrows(RuntimeException.class,() -> {
|
assertThrows(RuntimeException.class,() -> {
|
||||||
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
||||||
|
@ -412,7 +415,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTypesValidation_3(Nd4jBackend backend) {
|
public void testTypesValidation_3(Nd4jBackend backend) {
|
||||||
assertThrows(RuntimeException.class,() -> {
|
assertThrows(RuntimeException.class,() -> {
|
||||||
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
||||||
|
@ -422,6 +426,8 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTypesValidation_4(Nd4jBackend backend) {
|
public void testTypesValidation_4(Nd4jBackend backend) {
|
||||||
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
|
||||||
val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.DOUBLE);
|
val arrayY = Nd4j.create(new int[]{1, 0, 0, 4}, new long[]{4}, DataType.DOUBLE);
|
||||||
|
@ -485,7 +491,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBoolFloatCast2(){
|
public void testBoolFloatCast2(Nd4jBackend backend){
|
||||||
val first = Nd4j.zeros(DataType.FLOAT, 3, 5000);
|
val first = Nd4j.zeros(DataType.FLOAT, 3, 5000);
|
||||||
INDArray asBool = first.castTo(DataType.BOOL);
|
INDArray asBool = first.castTo(DataType.BOOL);
|
||||||
INDArray not = Transforms.not(asBool); //
|
INDArray not = Transforms.not(asBool); //
|
||||||
|
@ -516,7 +522,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAssignScalarSimple(){
|
public void testAssignScalarSimple(Nd4jBackend backend){
|
||||||
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
||||||
INDArray arr = Nd4j.scalar(dt, 10.0);
|
INDArray arr = Nd4j.scalar(dt, 10.0);
|
||||||
arr.assign(2.0);
|
arr.assign(2.0);
|
||||||
|
@ -526,7 +532,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSimple(){
|
public void testSimple(Nd4jBackend backend){
|
||||||
Nd4j.create(1);
|
Nd4j.create(1);
|
||||||
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, DataType.LONG}) {
|
for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, DataType.LONG}) {
|
||||||
// System.out.println("----- " + dt + " -----");
|
// System.out.println("----- " + dt + " -----");
|
||||||
|
@ -551,7 +557,7 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testWorkspaceBool(){
|
public void testWorkspaceBool(Nd4jBackend backend){
|
||||||
val conf = WorkspaceConfiguration.builder().minSize(10 * 1024 * 1024)
|
val conf = WorkspaceConfiguration.builder().minSize(10 * 1024 * 1024)
|
||||||
.overallocationLimit(1.0).policyAllocation(AllocationPolicy.OVERALLOCATE)
|
.overallocationLimit(1.0).policyAllocation(AllocationPolicy.OVERALLOCATE)
|
||||||
.policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL)
|
.policyLearning(LearningPolicy.FIRST_LOOP).policyMirroring(MirroringPolicy.FULL)
|
||||||
|
@ -574,8 +580,9 @@ public class MixedDataTypesTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
@Disabled("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
public void testArrayCreationFromPointer(Nd4jBackend backend) {
|
public void testArrayCreationFromPointer(Nd4jBackend backend) {
|
||||||
val source = Nd4j.create(new double[]{1, 2, 3, 4, 5});
|
val source = Nd4j.create(new double[]{1, 2, 3, 4, 5});
|
||||||
|
|
||||||
|
|
|
@ -40,13 +40,13 @@ public class NativeBlasTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void setUp(Nd4jBackend backend) {
|
public void setUp() {
|
||||||
Nd4j.getExecutioner().enableDebugMode(true);
|
Nd4j.getExecutioner().enableDebugMode(true);
|
||||||
Nd4j.getExecutioner().enableVerboseMode(true);
|
Nd4j.getExecutioner().enableVerboseMode(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
public void setDown(Nd4jBackend backend) {
|
public void setDown() {
|
||||||
Nd4j.getExecutioner().enableDebugMode(false);
|
Nd4j.getExecutioner().enableDebugMode(false);
|
||||||
Nd4j.getExecutioner().enableVerboseMode(false);
|
Nd4j.getExecutioner().enableVerboseMode(false);
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,18 +77,18 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
|
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
|
||||||
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
|
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4, 5});
|
||||||
double sim = Transforms.cosineSim(vec1, vec2);
|
double sim = Transforms.cosineSim(vec1, vec2);
|
||||||
assertEquals( 1, sim, 1e-1,getFailureMessage());
|
assertEquals( 1, sim, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCosineDistance(){
|
public void testCosineDistance(Nd4jBackend backend){
|
||||||
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3});
|
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3});
|
||||||
INDArray vec2 = Nd4j.create(new float[] {3, 5, 7});
|
INDArray vec2 = Nd4j.create(new float[] {3, 5, 7});
|
||||||
// 1-17*sqrt(2/581)
|
// 1-17*sqrt(2/581)
|
||||||
double distance = Transforms.cosineDistance(vec1, vec2);
|
double distance = Transforms.cosineDistance(vec1, vec2);
|
||||||
assertEquals(0.0025851, distance, 1e-7,getFailureMessage());
|
assertEquals(0.0025851, distance, 1e-7,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -97,7 +97,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
INDArray arr = Nd4j.create(new double[] {55, 55});
|
INDArray arr = Nd4j.create(new double[] {55, 55});
|
||||||
INDArray arr2 = Nd4j.create(new double[] {60, 60});
|
INDArray arr2 = Nd4j.create(new double[] {60, 60});
|
||||||
double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).z().getDouble(0);
|
double result = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr, arr2)).z().getDouble(0);
|
||||||
assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage());
|
assertEquals(7.0710678118654755, result, 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -137,7 +137,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi();
|
INDArray scalarMax = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).negi();
|
||||||
INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6);
|
INDArray postMax = Nd4j.ones(DataType.DOUBLE, 6);
|
||||||
Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1));
|
Nd4j.getExecutioner().exec(new ScalarMax(scalarMax, 1));
|
||||||
assertEquals(scalarMax, postMax,getFailureMessage());
|
assertEquals(scalarMax, postMax,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -147,14 +147,14 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1));
|
Nd4j.getExecutioner().exec(new SetRange(linspace, 0, 1));
|
||||||
for (int i = 0; i < linspace.length(); i++) {
|
for (int i = 0; i < linspace.length(); i++) {
|
||||||
double val = linspace.getDouble(i);
|
double val = linspace.getDouble(i);
|
||||||
assertTrue( val >= 0 && val <= 1,getFailureMessage());
|
assertTrue( val >= 0 && val <= 1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray linspace2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
INDArray linspace2 = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||||
Nd4j.getExecutioner().exec(new SetRange(linspace2, 2, 4));
|
Nd4j.getExecutioner().exec(new SetRange(linspace2, 2, 4));
|
||||||
for (int i = 0; i < linspace2.length(); i++) {
|
for (int i = 0; i < linspace2.length(); i++) {
|
||||||
double val = linspace2.getDouble(i);
|
double val = linspace2.getDouble(i);
|
||||||
assertTrue( val >= 2 && val <= 4,getFailureMessage());
|
assertTrue( val >= 2 && val <= 4,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,7 +163,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
public void testNormMax(Nd4jBackend backend) {
|
public void testNormMax(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).z().getDouble(0);
|
double normMax = Nd4j.getExecutioner().execAndReturn(new NormMax(arr)).z().getDouble(0);
|
||||||
assertEquals(4, normMax, 1e-1,getFailureMessage());
|
assertEquals(4, normMax, 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -187,7 +187,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
public void testNorm2(Nd4jBackend backend) {
|
public void testNorm2(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray arr = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).z().getDouble(0);
|
double norm2 = Nd4j.getExecutioner().execAndReturn(new Norm2(arr)).z().getDouble(0);
|
||||||
assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage());
|
assertEquals(5.4772255750516612, norm2, 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -198,7 +198,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
INDArray xDup = x.dup();
|
INDArray xDup = x.dup();
|
||||||
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
|
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
|
||||||
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
|
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
|
||||||
assertEquals(solution, x,getFailureMessage());
|
assertEquals(solution, x,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -221,13 +221,13 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
INDArray xDup = x.dup();
|
INDArray xDup = x.dup();
|
||||||
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
|
INDArray solution = Nd4j.valueArrayOf(5, 2.0);
|
||||||
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
|
opExecutioner.exec(new AddOp(new INDArray[]{x, xDup},new INDArray[]{x}));
|
||||||
assertEquals(solution, x,getFailureMessage());
|
assertEquals(solution, x,getFailureMessage(backend));
|
||||||
Sum acc = new Sum(x.dup());
|
Sum acc = new Sum(x.dup());
|
||||||
opExecutioner.exec(acc);
|
opExecutioner.exec(acc);
|
||||||
assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(10.0, acc.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
Prod prod = new Prod(x.dup());
|
Prod prod = new Prod(x.dup());
|
||||||
opExecutioner.exec(prod);
|
opExecutioner.exec(prod);
|
||||||
assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(32.0, prod.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -275,7 +275,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
Variance variance = new Variance(x.dup(), true);
|
Variance variance = new Variance(x.dup(), true);
|
||||||
opExecutioner.exec(variance);
|
opExecutioner.exec(variance);
|
||||||
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -284,14 +284,14 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIamax(Nd4jBackend backend) {
|
public void testIamax(Nd4jBackend backend) {
|
||||||
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||||
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage());
|
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIamax2(Nd4jBackend backend) {
|
public void testIamax2(Nd4jBackend backend) {
|
||||||
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||||
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage());
|
assertEquals( 3, Nd4j.getBlasWrapper().iamax(linspace),getFailureMessage(backend));
|
||||||
val op = new ArgAmax(linspace);
|
val op = new ArgAmax(linspace);
|
||||||
|
|
||||||
int iamax = Nd4j.getExecutioner().exec(op)[0].getInt(0);
|
int iamax = Nd4j.getExecutioner().exec(op)[0].getInt(0);
|
||||||
|
@ -307,11 +307,11 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
Mean mean = new Mean(x);
|
Mean mean = new Mean(x);
|
||||||
opExecutioner.exec(mean);
|
opExecutioner.exec(mean);
|
||||||
assertEquals( 3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals( 3.0, mean.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
Variance variance = new Variance(x.dup(), true);
|
Variance variance = new Variance(x.dup(), true);
|
||||||
opExecutioner.exec(variance);
|
opExecutioner.exec(variance);
|
||||||
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals( 2.5, variance.getFinalResult().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -321,7 +321,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
|
val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
|
||||||
val softMax = new SoftMax(arr);
|
val softMax = new SoftMax(arr);
|
||||||
opExecutioner.exec((CustomOp) softMax);
|
opExecutioner.exec((CustomOp) softMax);
|
||||||
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -332,7 +332,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
Pow pow = new Pow(oneThroughSix, 2);
|
Pow pow = new Pow(oneThroughSix, 2);
|
||||||
Nd4j.getExecutioner().exec(pow);
|
Nd4j.getExecutioner().exec(pow);
|
||||||
INDArray answer = Nd4j.create(new double[] {1, 4, 9, 16, 25, 36});
|
INDArray answer = Nd4j.create(new double[] {1, 4, 9, 16, 25, 36});
|
||||||
assertEquals(answer, pow.z(),getFailureMessage());
|
assertEquals(answer, pow.z(),getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -384,7 +384,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
Log log = new Log(slice);
|
Log log = new Log(slice);
|
||||||
opExecutioner.exec(log);
|
opExecutioner.exec(log);
|
||||||
INDArray assertion = Nd4j.create(new double[] {0., 1.09861229, 1.60943791});
|
INDArray assertion = Nd4j.create(new double[] {0., 1.09861229, 1.60943791});
|
||||||
assertEquals(assertion, slice,getFailureMessage());
|
assertEquals(assertion, slice,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -572,7 +572,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
expected[i] = (float) Math.exp(slice.getDouble(i));
|
expected[i] = (float) Math.exp(slice.getDouble(i));
|
||||||
Exp exp = new Exp(slice);
|
Exp exp = new Exp(slice);
|
||||||
opExecutioner.exec(exp);
|
opExecutioner.exec(exp);
|
||||||
assertEquals( Nd4j.create(expected), slice,getFailureMessage());
|
assertEquals( Nd4j.create(expected), slice,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -582,7 +582,7 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
|
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
|
||||||
val softMax = new SoftMax(arr);
|
val softMax = new SoftMax(arr);
|
||||||
opExecutioner.exec((CustomOp) softMax);
|
opExecutioner.exec((CustomOp) softMax);
|
||||||
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage());
|
assertEquals(1.0, softMax.outputArguments().get(0).sumNumber().doubleValue(), 1e-1,getFailureMessage(backend));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue