Add large test tags, ensure that small runs finish, get rid of test timeouts
parent
652b854083
commit
3e60302e8c
|
@ -31,7 +31,7 @@ jobs:
|
||||||
protoc --version
|
protoc --version
|
||||||
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DexcludedGroups="long-running-tests,large-resources" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test --fail-never
|
||||||
|
|
||||||
windows-x86_64:
|
windows-x86_64:
|
||||||
runs-on: windows-2019
|
runs-on: windows-2019
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
set "PATH=C:\msys64\usr\bin;%PATH%"
|
set "PATH=C:\msys64\usr\bin;%PATH%"
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
|
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DexcludedGroups="long-running-tests,large-resources" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test --fail-never
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,5 +22,6 @@ jobs:
|
||||||
cmake --version
|
cmake --version
|
||||||
protoc --version
|
protoc --version
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -DexcludedGroups=long-running-tests -DskipTestResourceEnforcement=true -Ptestresources -Pintegration-tests -Pnd4j-tests-cpu clean test
|
mvn -DexcludedGroups="long-running-tests,large-resources" -DskipTestResourceEnforcement=true -Ptestresources -Pintegration-tests -Pnd4j-tests-cpu clean test
|
||||||
|
mvn -Ptestresources -Pnd4j-tests-cpu -Dtest.offheap.size=14g -Dtest.heap.size=6g clean test
|
||||||
|
|
||||||
|
|
|
@ -34,5 +34,5 @@ jobs:
|
||||||
cmake --version
|
cmake --version
|
||||||
protoc --version
|
protoc --version
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Pnd4j-tests-cpu --also-make clean test
|
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DexcludedGroups="long-running-tests,large-resources" -Pnd4j-tests-cpu --also-make clean test --fail-never
|
||||||
|
|
||||||
|
|
|
@ -34,5 +34,6 @@ jobs:
|
||||||
cmake --version
|
cmake --version
|
||||||
protoc --version
|
protoc --version
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -DexcludedGroups=long-running-tests -DskipTestResourceEnforcement=true -Ptestresources -Pintegration-tests -Pnd4j-tests-cuda clean test
|
mvn -DexcludedGroups="long-running-tests,large-resources" -DskipTestResourceEnforcement=true -Ptestresources -Pintegration-tests -Pnd4j-tests-cuda clean test --fail-never
|
||||||
|
mvn -Ptestresources -Pnd4j-tests-cuda -Dtest.offheap.size=14g -Dtest.heap.size=6g clean test --fail-never
|
||||||
|
|
||||||
|
|
|
@ -35,5 +35,5 @@ jobs:
|
||||||
protoc --version
|
protoc --version
|
||||||
bash ./change-cuda-versions.sh 11.2
|
bash ./change-cuda-versions.sh 11.2
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-cuda-11.2,:samediff-import,:libnd4j" -Dlibnd4j.helper=cudnn -Ptest-nd4j-cuda --also-make -Dlibnd4j.chip=cuda clean test
|
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-cuda-11.2,:samediff-import,:libnd4j" -Dlibnd4j.helper=cudnn -Ptest-nd4j-cuda --also-make -Dlibnd4j.chip=cuda clean test --fail-never
|
||||||
|
|
||||||
|
|
|
@ -42,6 +42,17 @@ A few kinds of tags exist:
|
||||||
7. RNG: (rng) for RNG related tests
|
7. RNG: (rng) for RNG related tests
|
||||||
8. Samediff:(samediff) samediff related tests
|
8. Samediff:(samediff) samediff related tests
|
||||||
9. Training related functionality
|
9. Training related functionality
|
||||||
|
10. long-running-tests: The longer running tests that take a longer execution time
|
||||||
|
11. large-resources: tests requiring a large amount of ram/cpu (>= 2g up to 16g)
|
||||||
|
|
||||||
|
|
||||||
|
New maven properties for maven surefire:
|
||||||
|
test.offheap.size: tunes off heap size for javacpp
|
||||||
|
test.heap.size: tunes heap size of test jvms
|
||||||
|
|
||||||
|
|
||||||
|
Auto tuning the number of CPU cores for tests relative to the number of CPUs present
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Consequences
|
## Consequences
|
||||||
|
|
|
@ -58,6 +58,7 @@ import java.io.File;
|
||||||
import java.io.FileInputStream;
|
import java.io.FileInputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
|
import java.nio.Buffer;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
@ -171,7 +172,8 @@ public class ArrowConverter {
|
||||||
ByteBuffer direct = ByteBuffer.allocateDirect(fieldVector.getDataBuffer().capacity());
|
ByteBuffer direct = ByteBuffer.allocateDirect(fieldVector.getDataBuffer().capacity());
|
||||||
direct.order(ByteOrder.nativeOrder());
|
direct.order(ByteOrder.nativeOrder());
|
||||||
fieldVector.getDataBuffer().getBytes(0,direct);
|
fieldVector.getDataBuffer().getBytes(0,direct);
|
||||||
direct.rewind();
|
Buffer buffer1 = (Buffer) direct;
|
||||||
|
buffer1.rewind();
|
||||||
switch(type) {
|
switch(type) {
|
||||||
case Integer:
|
case Integer:
|
||||||
buffer = Nd4j.createBuffer(direct, DataType.INT,cols,0);
|
buffer = Nd4j.createBuffer(direct, DataType.INT,cols,0);
|
||||||
|
|
|
@ -119,6 +119,7 @@
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<configuration>
|
<configuration>
|
||||||
|
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
|
||||||
<classpathDependencyExcludes>
|
<classpathDependencyExcludes>
|
||||||
<classpathDependencyExclude>com.google.android:android
|
<classpathDependencyExclude>com.google.android:android
|
||||||
</classpathDependencyExclude>
|
</classpathDependencyExclude>
|
||||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -19,31 +19,26 @@
|
||||||
*/
|
*/
|
||||||
package org.deeplearning4j.datasets;
|
package org.deeplearning4j.datasets;
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.base.MnistFetcher;
|
|
||||||
import org.deeplearning4j.common.resources.DL4JResources;
|
import org.deeplearning4j.common.resources.DL4JResources;
|
||||||
|
import org.deeplearning4j.datasets.base.MnistFetcher;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.junit.jupiter.api.*;
|
import org.junit.jupiter.api.*;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
|
||||||
import org.nd4j.common.tests.tags.NativeTag;
|
import org.nd4j.common.tests.tags.NativeTag;
|
||||||
import org.nd4j.common.tests.tags.TagNames;
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@DisplayName("Mnist Fetcher Test")
|
@DisplayName("Mnist Fetcher Test")
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@ -65,6 +60,9 @@ class MnistFetcherTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("Test Mnist")
|
@DisplayName("Test Mnist")
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
|
@Tag(TagNames.FILE_IO)
|
||||||
void testMnist() throws Exception {
|
void testMnist() throws Exception {
|
||||||
MnistDataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1);
|
MnistDataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1);
|
||||||
int count = 0;
|
int count = 0;
|
||||||
|
@ -91,6 +89,9 @@ class MnistFetcherTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("Test Mnist Data Fetcher")
|
@DisplayName("Test Mnist Data Fetcher")
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
|
@Tag(TagNames.FILE_IO)
|
||||||
void testMnistDataFetcher() throws Exception {
|
void testMnistDataFetcher() throws Exception {
|
||||||
MnistFetcher mnistFetcher = new MnistFetcher();
|
MnistFetcher mnistFetcher = new MnistFetcher();
|
||||||
File mnistDir = mnistFetcher.downloadAndUntar();
|
File mnistDir = mnistFetcher.downloadAndUntar();
|
||||||
|
@ -99,6 +100,9 @@ class MnistFetcherTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
|
@Tag(TagNames.FILE_IO)
|
||||||
public void testMnistSubset() throws Exception {
|
public void testMnistSubset() throws Exception {
|
||||||
final int numExamples = 100;
|
final int numExamples = 100;
|
||||||
MnistDataSetIterator iter1 = new MnistDataSetIterator(10, numExamples, false, true, true, 123);
|
MnistDataSetIterator iter1 = new MnistDataSetIterator(10, numExamples, false, true, true, 123);
|
||||||
|
@ -144,6 +148,9 @@ class MnistFetcherTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("Test Subset Repeatability")
|
@DisplayName("Test Subset Repeatability")
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
|
@Tag(TagNames.FILE_IO)
|
||||||
void testSubsetRepeatability() throws Exception {
|
void testSubsetRepeatability() throws Exception {
|
||||||
MnistDataSetIterator it = new MnistDataSetIterator(1, 1, false, false, true, 0);
|
MnistDataSetIterator it = new MnistDataSetIterator(1, 1, false, false, true, 0);
|
||||||
DataSet d1 = it.next();
|
DataSet d1 = it.next();
|
||||||
|
|
|
@ -51,6 +51,7 @@ public class TestEmnistDataSetIterator extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Tag(TagNames.LONG_TEST)
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public void testEmnistDataSetIterator() throws Exception {
|
public void testEmnistDataSetIterator() throws Exception {
|
||||||
int batchSize = 128;
|
int batchSize = 128;
|
||||||
|
|
||||||
|
|
|
@ -63,6 +63,8 @@ import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
import org.junit.jupiter.api.*;
|
import org.junit.jupiter.api.*;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
import org.nd4j.common.tests.tags.NativeTag;
|
import org.nd4j.common.tests.tags.NativeTag;
|
||||||
import org.nd4j.common.tests.tags.TagNames;
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
@ -1717,8 +1719,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
MultiLayerTest.CheckModelsListener listener = new MultiLayerTest.CheckModelsListener();
|
MultiLayerTest.CheckModelsListener listener = new MultiLayerTest.CheckModelsListener();
|
||||||
net.setListeners(listener);
|
net.setListeners(listener);
|
||||||
|
|
||||||
INDArray f = Nd4j.create(1,10);
|
INDArray f = Nd4j.create(DataType.DOUBLE,1,10);
|
||||||
INDArray l = Nd4j.create(1,10);
|
INDArray l = Nd4j.create(DataType.DOUBLE,1,10);
|
||||||
DataSet ds = new DataSet(f,l);
|
DataSet ds = new DataSet(f,l);
|
||||||
MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(f,l);
|
MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(f,l);
|
||||||
|
|
||||||
|
@ -2117,9 +2119,10 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
|
@Tag(TagNames.NEEDS_VERIFY)
|
||||||
|
@Disabled
|
||||||
public void testCompGraphInputReuse() {
|
public void testCompGraphInputReuse() {
|
||||||
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
|
|
||||||
|
|
||||||
int inputSize = 5;
|
int inputSize = 5;
|
||||||
int outputSize = 6;
|
int outputSize = 6;
|
||||||
int layerSize = 3;
|
int layerSize = 3;
|
||||||
|
@ -2134,7 +2137,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
.setOutputs("out")
|
.setOutputs("out")
|
||||||
.addLayer("0",new DenseLayer.Builder().nIn(inputSize).nOut(layerSize).build(),"in")
|
.addLayer("0",new DenseLayer.Builder().nIn(inputSize).nOut(layerSize).build(),"in")
|
||||||
.addVertex("combine", new MergeVertex(), "0", "0", "0")
|
.addVertex("combine", new MergeVertex(), "0", "0", "0")
|
||||||
.addLayer("out",new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(3*layerSize).nOut(outputSize)
|
.addLayer("out",new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(3*layerSize)
|
||||||
|
.nOut(outputSize)
|
||||||
.activation(Activation.SIGMOID).build(),"combine")
|
.activation(Activation.SIGMOID).build(),"combine")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -2143,8 +2147,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
|
|
||||||
|
|
||||||
int dataSize = 11;
|
int dataSize = 11;
|
||||||
INDArray features = Nd4j.rand(new int[] {dataSize, inputSize});
|
INDArray features = Nd4j.rand(DataType.DOUBLE,new int[] {dataSize, inputSize});
|
||||||
INDArray labels = Nd4j.rand(new int[] {dataSize, outputSize});
|
INDArray labels = Nd4j.rand(DataType.DOUBLE,new int[] {dataSize, outputSize});
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{features})
|
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{features})
|
||||||
.labels(new INDArray[]{labels}));
|
.labels(new INDArray[]{labels}));
|
||||||
|
|
|
@ -23,8 +23,11 @@ import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.conf.distribution.*;
|
import org.deeplearning4j.nn.conf.distribution.*;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
import org.junit.jupiter.api.*;
|
import org.junit.jupiter.api.*;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
import org.nd4j.common.tests.tags.NativeTag;
|
import org.nd4j.common.tests.tags.NativeTag;
|
||||||
import org.nd4j.common.tests.tags.TagNames;
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -69,14 +72,19 @@ class LegacyWeightInitTest extends BaseDL4JTest {
|
||||||
final long[] shape = { 5, 5 };
|
final long[] shape = { 5, 5 };
|
||||||
final long fanIn = shape[0];
|
final long fanIn = shape[0];
|
||||||
final long fanOut = shape[1];
|
final long fanOut = shape[1];
|
||||||
final INDArray inLegacy = Nd4j.create(fanIn * fanOut);
|
final INDArray inLegacy = Nd4j.create(DataType.DOUBLE,fanIn * fanOut);
|
||||||
final INDArray inTest = inLegacy.dup();
|
final INDArray inTest = inLegacy.dup();
|
||||||
for (WeightInit legacyWi : WeightInit.values()) {
|
for (WeightInit legacyWi : WeightInit.values()) {
|
||||||
if (legacyWi != WeightInit.DISTRIBUTION) {
|
if (legacyWi != WeightInit.DISTRIBUTION) {
|
||||||
Nd4j.getRandom().setSeed(SEED);
|
Nd4j.getRandom().setSeed(SEED);
|
||||||
final INDArray expected = WeightInitUtil.initWeights(fanIn, fanOut, shape, legacyWi, null, inLegacy);
|
final INDArray expected = WeightInitUtil.
|
||||||
|
initWeights(fanIn, fanOut, shape, legacyWi, null, inLegacy)
|
||||||
|
.castTo(DataType.DOUBLE);
|
||||||
Nd4j.getRandom().setSeed(SEED);
|
Nd4j.getRandom().setSeed(SEED);
|
||||||
final INDArray actual = legacyWi.getWeightInitFunction().init(fanIn, fanOut, shape, WeightInitUtil.DEFAULT_WEIGHT_INIT_ORDER, inTest);
|
final INDArray actual = legacyWi.getWeightInitFunction()
|
||||||
|
.init(fanIn, fanOut, shape,
|
||||||
|
WeightInitUtil.DEFAULT_WEIGHT_INIT_ORDER, inTest)
|
||||||
|
.castTo(DataType.DOUBLE);
|
||||||
assertArrayEquals(shape, actual.shape(),"Incorrect shape for " + legacyWi + "!");
|
assertArrayEquals(shape, actual.shape(),"Incorrect shape for " + legacyWi + "!");
|
||||||
assertEquals( expected, actual,"Incorrect weight initialization for " + legacyWi + "!");
|
assertEquals( expected, actual,"Incorrect weight initialization for " + legacyWi + "!");
|
||||||
}
|
}
|
||||||
|
@ -88,17 +96,24 @@ class LegacyWeightInitTest extends BaseDL4JTest {
|
||||||
*/
|
*/
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("Init Params From Distribution")
|
@DisplayName("Init Params From Distribution")
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
|
@Disabled(TagNames.NEEDS_VERIFY)
|
||||||
void initParamsFromDistribution() {
|
void initParamsFromDistribution() {
|
||||||
// To make identity happy
|
// To make identity happy
|
||||||
final long[] shape = { 3, 7 };
|
final long[] shape = { 3, 7 };
|
||||||
final long fanIn = shape[0];
|
final long fanIn = shape[0];
|
||||||
final long fanOut = shape[1];
|
final long fanOut = shape[1];
|
||||||
final INDArray inLegacy = Nd4j.create(fanIn * fanOut);
|
final INDArray inLegacy = Nd4j.create(DataType.DOUBLE,fanIn * fanOut);
|
||||||
final INDArray inTest = inLegacy.dup();
|
final INDArray inTest = inLegacy.dup();
|
||||||
for (Distribution dist : distributions) {
|
for (Distribution dist : distributions) {
|
||||||
Nd4j.getRandom().setSeed(SEED);
|
Nd4j.getRandom().setSeed(SEED);
|
||||||
final INDArray expected = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.DISTRIBUTION, Distributions.createDistribution(dist), inLegacy);
|
final INDArray expected = WeightInitUtil
|
||||||
final INDArray actual = new WeightInitDistribution(dist).init(fanIn, fanOut, shape, WeightInitUtil.DEFAULT_WEIGHT_INIT_ORDER, inTest);
|
.initWeights(fanIn, fanOut, shape, WeightInit.DISTRIBUTION,
|
||||||
|
Distributions.createDistribution(dist), inLegacy)
|
||||||
|
.castTo(DataType.DOUBLE);
|
||||||
|
final INDArray actual = new WeightInitDistribution(dist)
|
||||||
|
.init(fanIn, fanOut, shape, WeightInitUtil.DEFAULT_WEIGHT_INIT_ORDER,
|
||||||
|
inTest).castTo(DataType.DOUBLE);
|
||||||
assertArrayEquals(shape, actual.shape(),"Incorrect shape for " + dist.getClass().getSimpleName() + "!");
|
assertArrayEquals(shape, actual.shape(),"Incorrect shape for " + dist.getClass().getSimpleName() + "!");
|
||||||
assertEquals( expected, actual,"Incorrect weight initialization for " + dist.getClass().getSimpleName() + "!");
|
assertEquals( expected, actual,"Incorrect weight initialization for " + dist.getClass().getSimpleName() + "!");
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,6 +34,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.junit.jupiter.api.Tag;
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
import org.nd4j.common.tests.tags.NativeTag;
|
import org.nd4j.common.tests.tags.NativeTag;
|
||||||
import org.nd4j.common.tests.tags.TagNames;
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
@ -56,14 +58,14 @@ public class RandomTests extends BaseDL4JTest {
|
||||||
*
|
*
|
||||||
* @throws Exception
|
* @throws Exception
|
||||||
*/
|
*/
|
||||||
@Test
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
public void testModelInitialParamsEquality1() throws Exception {
|
public void testModelInitialParamsEquality1() throws Exception {
|
||||||
final List<Model> models = new CopyOnWriteArrayList<>();
|
final List<Model> models = new CopyOnWriteArrayList<>();
|
||||||
|
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
Thread thread = new Thread(new Runnable() {
|
Thread thread = new Thread(() -> {
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(119) // Training iterations as above
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(119) // Training iterations as above
|
||||||
.l2(0.0005)
|
.l2(0.0005)
|
||||||
//.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
|
//.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
|
||||||
|
@ -91,7 +93,6 @@ public class RandomTests extends BaseDL4JTest {
|
||||||
network.init();
|
network.init();
|
||||||
|
|
||||||
models.add(network);
|
models.add(network);
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
thread.start();
|
thread.start();
|
||||||
|
|
|
@ -47,8 +47,9 @@
|
||||||
<configuration>
|
<configuration>
|
||||||
<forkCount>${cpu.core.count}</forkCount>
|
<forkCount>${cpu.core.count}</forkCount>
|
||||||
<reuseForks>false</reuseForks>
|
<reuseForks>false</reuseForks>
|
||||||
|
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
|
||||||
<argLine>-Ddtype=float -Dfile.encoding=UTF-8
|
<argLine>-Ddtype=float -Dfile.encoding=UTF-8
|
||||||
-Dtest.solr.allowed.securerandom=NativePRNG
|
-Dtest.solr.allowed.securerandom=NativePRNG -Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}
|
||||||
</argLine>
|
</argLine>
|
||||||
<includes>
|
<includes>
|
||||||
<!-- Default setting only runs tests that start/end with "Test" -->
|
<!-- Default setting only runs tests that start/end with "Test" -->
|
||||||
|
|
|
@ -48,6 +48,8 @@ import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
@DisplayName("Tuple Stream Data Set Iterator Test")
|
@DisplayName("Tuple Stream Data Set Iterator Test")
|
||||||
@Tag(TagNames.SOLR)
|
@Tag(TagNames.SOLR)
|
||||||
@Tag(TagNames.DIST_SYSTEMS)
|
@Tag(TagNames.DIST_SYSTEMS)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
class TupleStreamDataSetIteratorTest extends SolrCloudTestCase {
|
class TupleStreamDataSetIteratorTest extends SolrCloudTestCase {
|
||||||
|
|
||||||
static {
|
static {
|
||||||
|
|
|
@ -41,7 +41,8 @@
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<configuration>
|
<configuration>
|
||||||
<argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g
|
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
|
||||||
|
<argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx${test.heap.size}
|
||||||
-Dtest.solr.allowed.securerandom=NativePRNG
|
-Dtest.solr.allowed.securerandom=NativePRNG
|
||||||
</argLine>
|
</argLine>
|
||||||
<includes>
|
<includes>
|
||||||
|
|
|
@ -76,6 +76,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@Tag(TagNames.FILE_IO)
|
@Tag(TagNames.FILE_IO)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
public class SequenceVectorsTest extends BaseDL4JTest {
|
public class SequenceVectorsTest extends BaseDL4JTest {
|
||||||
|
|
||||||
protected static final Logger logger = LoggerFactory.getLogger(SequenceVectorsTest.class);
|
protected static final Logger logger = LoggerFactory.getLogger(SequenceVectorsTest.class);
|
||||||
|
|
|
@ -424,12 +424,7 @@ public class GradientCheckUtil {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"Invalid labels arrays: expect " + c.net.getNumOutputArrays() + " outputs");
|
"Invalid labels arrays: expect " + c.net.getNumOutputArrays() + " outputs");
|
||||||
|
|
||||||
DataType dataType = DataTypeUtil.getDtypeFromContext();
|
|
||||||
if (dataType != DataType.DOUBLE) {
|
|
||||||
throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision ("
|
|
||||||
+ "is: " + dataType + "). Double precision must be used for gradient checks. Set "
|
|
||||||
+ "DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil");
|
|
||||||
}
|
|
||||||
|
|
||||||
DataType netDataType = c.net.getConfiguration().getDataType();
|
DataType netDataType = c.net.getConfiguration().getDataType();
|
||||||
if (netDataType != DataType.DOUBLE) {
|
if (netDataType != DataType.DOUBLE) {
|
||||||
|
|
|
@ -21,6 +21,8 @@
|
||||||
package org.deeplearning4j.spark.models.sequencevectors;
|
package org.deeplearning4j.spark.models.sequencevectors;
|
||||||
|
|
||||||
import com.sun.jna.Platform;
|
import com.sun.jna.Platform;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
@ -35,9 +37,12 @@ import org.deeplearning4j.spark.models.word2vec.SparkWord2VecTest;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
||||||
import org.junit.jupiter.api.*;
|
import org.junit.jupiter.api.*;
|
||||||
import org.nd4j.common.primitives.Counter;
|
import org.nd4j.common.primitives.Counter;
|
||||||
|
import org.nd4j.common.resources.Downloader;
|
||||||
import org.nd4j.common.tests.tags.NativeTag;
|
import org.nd4j.common.tests.tags.NativeTag;
|
||||||
import org.nd4j.common.tests.tags.TagNames;
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.net.URI;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -47,6 +52,7 @@ import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||||
@Tag(TagNames.SPARK)
|
@Tag(TagNames.SPARK)
|
||||||
@Tag(TagNames.DIST_SYSTEMS)
|
@Tag(TagNames.DIST_SYSTEMS)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Slf4j
|
||||||
public class SparkSequenceVectorsTest extends BaseDL4JTest {
|
public class SparkSequenceVectorsTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -57,6 +63,27 @@ public class SparkSequenceVectorsTest extends BaseDL4JTest {
|
||||||
protected static List<Sequence<VocabWord>> sequencesCyclic;
|
protected static List<Sequence<VocabWord>> sequencesCyclic;
|
||||||
private JavaSparkContext sc;
|
private JavaSparkContext sc;
|
||||||
|
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
@SneakyThrows
|
||||||
|
public static void beforeAll() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp");
|
||||||
|
File binDir = new File(hadoopHome,"bin");
|
||||||
|
if(!binDir.exists())
|
||||||
|
binDir.mkdirs();
|
||||||
|
File outputFile = new File(binDir,"winutils.exe");
|
||||||
|
if(!outputFile.exists()) {
|
||||||
|
log.info("Fixing spark for windows");
|
||||||
|
Downloader.download("winutils.exe",
|
||||||
|
URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(),
|
||||||
|
outputFile,"db24b404d2331a1bec7443336a5171f1",3);
|
||||||
|
}
|
||||||
|
|
||||||
|
System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void setUp() throws Exception {
|
public void setUp() throws Exception {
|
||||||
if (sequencesCyclic == null) {
|
if (sequencesCyclic == null) {
|
||||||
|
|
|
@ -20,6 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.models.word2vec;
|
package org.deeplearning4j.spark.models.word2vec;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
@ -35,11 +38,14 @@ import org.deeplearning4j.spark.models.sequencevectors.export.SparkModelExporter
|
||||||
import org.deeplearning4j.spark.models.sequencevectors.learning.elements.SparkSkipGram;
|
import org.deeplearning4j.spark.models.sequencevectors.learning.elements.SparkSkipGram;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
||||||
import org.junit.jupiter.api.*;
|
import org.junit.jupiter.api.*;
|
||||||
|
import org.nd4j.common.resources.Downloader;
|
||||||
import org.nd4j.common.tests.tags.NativeTag;
|
import org.nd4j.common.tests.tags.NativeTag;
|
||||||
import org.nd4j.common.tests.tags.TagNames;
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
|
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import java.net.URI;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -48,6 +54,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
@Tag(TagNames.SPARK)
|
@Tag(TagNames.SPARK)
|
||||||
@Tag(TagNames.DIST_SYSTEMS)
|
@Tag(TagNames.DIST_SYSTEMS)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Slf4j
|
||||||
public class SparkWord2VecTest extends BaseDL4JTest {
|
public class SparkWord2VecTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -58,6 +65,27 @@ public class SparkWord2VecTest extends BaseDL4JTest {
|
||||||
private static List<String> sentences;
|
private static List<String> sentences;
|
||||||
private JavaSparkContext sc;
|
private JavaSparkContext sc;
|
||||||
|
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
@SneakyThrows
|
||||||
|
public static void beforeAll() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp");
|
||||||
|
File binDir = new File(hadoopHome,"bin");
|
||||||
|
if(!binDir.exists())
|
||||||
|
binDir.mkdirs();
|
||||||
|
File outputFile = new File(binDir,"winutils.exe");
|
||||||
|
if(!outputFile.exists()) {
|
||||||
|
log.info("Fixing spark for windows");
|
||||||
|
Downloader.download("winutils.exe",
|
||||||
|
URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(),
|
||||||
|
outputFile,"db24b404d2331a1bec7443336a5171f1",3);
|
||||||
|
}
|
||||||
|
|
||||||
|
System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void setUp() throws Exception {
|
public void setUp() throws Exception {
|
||||||
if (sentences == null) {
|
if (sentences == null) {
|
||||||
|
|
|
@ -21,11 +21,15 @@
|
||||||
package org.deeplearning4j.spark.models.embeddings.word2vec;
|
package org.deeplearning4j.spark.models.embeddings.word2vec;
|
||||||
|
|
||||||
import com.sun.jna.Platform;
|
import com.sun.jna.Platform;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
|
||||||
|
|
||||||
|
import org.deeplearning4j.common.resources.DL4JResources;
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.Tag;
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
@ -41,11 +45,14 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFac
|
||||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.common.resources.Downloader;
|
||||||
|
import org.nd4j.common.resources.strumpf.StrumpfResolver;
|
||||||
import org.nd4j.common.tests.tags.NativeTag;
|
import org.nd4j.common.tests.tags.NativeTag;
|
||||||
import org.nd4j.common.tests.tags.TagNames;
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.net.URI;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -53,21 +60,37 @@ import java.util.Collection;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@Disabled
|
|
||||||
@Tag(TagNames.FILE_IO)
|
@Tag(TagNames.FILE_IO)
|
||||||
@Tag(TagNames.SPARK)
|
@Tag(TagNames.SPARK)
|
||||||
@Tag(TagNames.DIST_SYSTEMS)
|
@Tag(TagNames.DIST_SYSTEMS)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Slf4j
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public class Word2VecTest {
|
public class Word2VecTest {
|
||||||
|
@BeforeAll
|
||||||
|
@SneakyThrows
|
||||||
|
public static void beforeAll() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp");
|
||||||
|
File binDir = new File(hadoopHome,"bin");
|
||||||
|
if(!binDir.exists())
|
||||||
|
binDir.mkdirs();
|
||||||
|
File outputFile = new File(binDir,"winutils.exe");
|
||||||
|
if(!outputFile.exists()) {
|
||||||
|
log.info("Fixing spark for windows");
|
||||||
|
Downloader.download("winutils.exe",
|
||||||
|
URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(),
|
||||||
|
outputFile,"db24b404d2331a1bec7443336a5171f1",3);
|
||||||
|
}
|
||||||
|
|
||||||
|
System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConcepts(@TempDir Path testDir) throws Exception {
|
public void testConcepts(@TempDir Path testDir) throws Exception {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// These are all default values for word2vec
|
// These are all default values for word2vec
|
||||||
SparkConf sparkConf = new SparkConf().setMaster("local[8]")
|
SparkConf sparkConf = new SparkConf().setMaster("local[8]")
|
||||||
.set("spark.driver.host", "localhost")
|
.set("spark.driver.host", "localhost")
|
||||||
|
|
|
@ -20,21 +20,50 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.text;
|
package org.deeplearning4j.spark.text;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables;
|
import org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables;
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.nd4j.common.resources.Downloader;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.lang.reflect.Field;
|
import java.lang.reflect.Field;
|
||||||
|
import java.net.URI;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
@Slf4j
|
||||||
public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable {
|
public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable {
|
||||||
protected transient JavaSparkContext sc;
|
protected transient JavaSparkContext sc;
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
@SneakyThrows
|
||||||
|
public static void beforeAll() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp");
|
||||||
|
File binDir = new File(hadoopHome,"bin");
|
||||||
|
if(!binDir.exists())
|
||||||
|
binDir.mkdirs();
|
||||||
|
File outputFile = new File(binDir,"winutils.exe");
|
||||||
|
if(!outputFile.exists()) {
|
||||||
|
log.info("Fixing spark for windows");
|
||||||
|
Downloader.download("winutils.exe",
|
||||||
|
URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(),
|
||||||
|
outputFile,"db24b404d2331a1bec7443336a5171f1",3);
|
||||||
|
}
|
||||||
|
|
||||||
|
System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 120000L;
|
return 120000L;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.spark.text;
|
package org.deeplearning4j.spark.text;
|
||||||
|
|
||||||
import com.sun.jna.Platform;
|
import com.sun.jna.Platform;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaPairRDD;
|
import org.apache.spark.api.java.JavaPairRDD;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
|
@ -35,10 +36,8 @@ import org.deeplearning4j.spark.models.embeddings.word2vec.Word2Vec;
|
||||||
import org.deeplearning4j.spark.text.functions.CountCumSum;
|
import org.deeplearning4j.spark.text.functions.CountCumSum;
|
||||||
import org.deeplearning4j.spark.text.functions.TextPipeline;
|
import org.deeplearning4j.spark.text.functions.TextPipeline;
|
||||||
import org.deeplearning4j.text.stopwords.StopWords;
|
import org.deeplearning4j.text.stopwords.StopWords;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.*;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.nd4j.common.resources.Downloader;
|
||||||
import org.junit.jupiter.api.Tag;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.nd4j.common.tests.tags.NativeTag;
|
import org.nd4j.common.tests.tags.NativeTag;
|
||||||
import org.nd4j.common.tests.tags.TagNames;
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -48,6 +47,8 @@ import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import scala.Tuple2;
|
import scala.Tuple2;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.net.URI;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
|
|
||||||
|
@ -74,6 +75,26 @@ public class TextPipelineTest extends BaseSparkTest {
|
||||||
return sc.parallelize(sentenceList, 2);
|
return sc.parallelize(sentenceList, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
@SneakyThrows
|
||||||
|
public static void beforeAll() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp");
|
||||||
|
File binDir = new File(hadoopHome,"bin");
|
||||||
|
if(!binDir.exists())
|
||||||
|
binDir.mkdirs();
|
||||||
|
File outputFile = new File(binDir,"winutils.exe");
|
||||||
|
if(!outputFile.exists()) {
|
||||||
|
log.info("Fixing spark for windows");
|
||||||
|
Downloader.download("winutils.exe",
|
||||||
|
URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(),
|
||||||
|
outputFile,"db24b404d2331a1bec7443336a5171f1",3);
|
||||||
|
}
|
||||||
|
|
||||||
|
System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void before() throws Exception {
|
public void before() throws Exception {
|
||||||
conf = new SparkConf().setMaster("local[4]").setAppName("sparktest").set("spark.driver.host", "localhost");
|
conf = new SparkConf().setMaster("local[4]").setAppName("sparktest").set("spark.driver.host", "localhost");
|
||||||
|
@ -102,10 +123,6 @@ public class TextPipelineTest extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTokenizer() throws Exception {
|
public void testTokenizer() throws Exception {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
JavaRDD<String> corpusRDD = getCorpusRDD(sc);
|
JavaRDD<String> corpusRDD = getCorpusRDD(sc);
|
||||||
Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
|
Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
|
||||||
|
|
|
@ -20,6 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.spark.parameterserver;
|
package org.deeplearning4j.spark.parameterserver;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
@ -29,7 +32,9 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
|
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
|
||||||
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
|
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.nd4j.common.resources.Downloader;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
@ -37,12 +42,14 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.learning.config.Nesterovs;
|
import org.nd4j.linalg.learning.config.Nesterovs;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import java.net.URI;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable {
|
public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable {
|
||||||
protected transient JavaSparkContext sc;
|
protected transient JavaSparkContext sc;
|
||||||
protected transient INDArray labels;
|
protected transient INDArray labels;
|
||||||
|
@ -60,6 +67,27 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable
|
||||||
return 120000L;
|
return 120000L;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
@SneakyThrows
|
||||||
|
public static void beforeAll() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp");
|
||||||
|
File binDir = new File(hadoopHome,"bin");
|
||||||
|
if(!binDir.exists())
|
||||||
|
binDir.mkdirs();
|
||||||
|
File outputFile = new File(binDir,"winutils.exe");
|
||||||
|
if(!outputFile.exists()) {
|
||||||
|
log.info("Fixing spark for windows");
|
||||||
|
Downloader.download("winutils.exe",
|
||||||
|
URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(),
|
||||||
|
outputFile,"db24b404d2331a1bec7443336a5171f1",3);
|
||||||
|
}
|
||||||
|
|
||||||
|
System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void before() {
|
public void before() {
|
||||||
|
|
||||||
|
|
|
@ -40,10 +40,6 @@ public class SharedTrainingAccumulationFunctionTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAccumulation1() throws Exception {
|
public void testAccumulation1() throws Exception {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
INDArray updates1 = Nd4j.create(1000).assign(1.0);
|
INDArray updates1 = Nd4j.create(1000).assign(1.0);
|
||||||
INDArray updates2 = Nd4j.create(1000).assign(2.0);
|
INDArray updates2 = Nd4j.create(1000).assign(2.0);
|
||||||
INDArray expUpdates = Nd4j.create(1000).assign(3.0);
|
INDArray expUpdates = Nd4j.create(1000).assign(3.0);
|
||||||
|
|
|
@ -43,10 +43,6 @@ public class SharedTrainingAggregateFunctionTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAggregate1() throws Exception {
|
public void testAggregate1() throws Exception {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
INDArray updates1 = Nd4j.create(1000).assign(1.0);
|
INDArray updates1 = Nd4j.create(1000).assign(1.0);
|
||||||
INDArray updates2 = Nd4j.create(1000).assign(2.0);
|
INDArray updates2 = Nd4j.create(1000).assign(2.0);
|
||||||
INDArray expUpdates = Nd4j.create(1000).assign(3.0);
|
INDArray expUpdates = Nd4j.create(1000).assign(3.0);
|
||||||
|
|
|
@ -21,15 +21,21 @@
|
||||||
package org.deeplearning4j.spark.parameterserver.iterators;
|
package org.deeplearning4j.spark.parameterserver.iterators;
|
||||||
|
|
||||||
import com.sun.jna.Platform;
|
import com.sun.jna.Platform;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Tag;
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.common.resources.Downloader;
|
||||||
import org.nd4j.common.tests.tags.NativeTag;
|
import org.nd4j.common.tests.tags.NativeTag;
|
||||||
import org.nd4j.common.tests.tags.TagNames;
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.net.URI;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -39,17 +45,35 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
@Tag(TagNames.SPARK)
|
@Tag(TagNames.SPARK)
|
||||||
@Tag(TagNames.DIST_SYSTEMS)
|
@Tag(TagNames.DIST_SYSTEMS)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Slf4j
|
||||||
public class VirtualDataSetIteratorTest {
|
public class VirtualDataSetIteratorTest {
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void setUp() throws Exception {}
|
public void setUp() throws Exception {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
@SneakyThrows
|
||||||
|
public static void beforeAll() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp");
|
||||||
|
File binDir = new File(hadoopHome,"bin");
|
||||||
|
if(!binDir.exists())
|
||||||
|
binDir.mkdirs();
|
||||||
|
File outputFile = new File(binDir,"winutils.exe");
|
||||||
|
if(!outputFile.exists()) {
|
||||||
|
log.info("Fixing spark for windows");
|
||||||
|
Downloader.download("winutils.exe",
|
||||||
|
URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(),
|
||||||
|
outputFile,"db24b404d2331a1bec7443336a5171f1",3);
|
||||||
|
}
|
||||||
|
|
||||||
|
System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSimple1() throws Exception {
|
public void testSimple1() throws Exception {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
List<Iterator<DataSet>> iterators = new ArrayList<>();
|
List<Iterator<DataSet>> iterators = new ArrayList<>();
|
||||||
|
|
||||||
List<DataSet> first = new ArrayList<>();
|
List<DataSet> first = new ArrayList<>();
|
||||||
|
|
|
@ -21,12 +21,18 @@
|
||||||
package org.deeplearning4j.spark.parameterserver.iterators;
|
package org.deeplearning4j.spark.parameterserver.iterators;
|
||||||
|
|
||||||
import com.sun.jna.Platform;
|
import com.sun.jna.Platform;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Tag;
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.common.resources.Downloader;
|
||||||
import org.nd4j.common.tests.tags.NativeTag;
|
import org.nd4j.common.tests.tags.NativeTag;
|
||||||
import org.nd4j.common.tests.tags.TagNames;
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.net.URI;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -35,18 +41,36 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
@Tag(TagNames.SPARK)
|
@Tag(TagNames.SPARK)
|
||||||
@Tag(TagNames.DIST_SYSTEMS)
|
@Tag(TagNames.DIST_SYSTEMS)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Slf4j
|
||||||
public class VirtualIteratorTest {
|
public class VirtualIteratorTest {
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void setUp() throws Exception {
|
public void setUp() throws Exception {
|
||||||
//
|
//
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
@SneakyThrows
|
||||||
|
public static void beforeAll() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp");
|
||||||
|
File binDir = new File(hadoopHome,"bin");
|
||||||
|
if(!binDir.exists())
|
||||||
|
binDir.mkdirs();
|
||||||
|
File outputFile = new File(binDir,"winutils.exe");
|
||||||
|
if(!outputFile.exists()) {
|
||||||
|
log.info("Fixing spark for windows");
|
||||||
|
Downloader.download("winutils.exe",
|
||||||
|
URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(),
|
||||||
|
outputFile,"db24b404d2331a1bec7443336a5171f1",3);
|
||||||
|
}
|
||||||
|
|
||||||
|
System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testIteration1() throws Exception {
|
public void testIteration1() throws Exception {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
List<Integer> integers = new ArrayList<>();
|
List<Integer> integers = new ArrayList<>();
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
integers.add(i);
|
integers.add(i);
|
||||||
|
|
|
@ -21,19 +21,24 @@
|
||||||
package org.deeplearning4j.spark.parameterserver.modelimport.elephas;
|
package org.deeplearning4j.spark.parameterserver.modelimport.elephas;
|
||||||
|
|
||||||
import com.sun.jna.Platform;
|
import com.sun.jna.Platform;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
|
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
|
||||||
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
|
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
|
||||||
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
|
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
|
||||||
import org.deeplearning4j.spark.parameterserver.BaseSparkTest;
|
import org.deeplearning4j.spark.parameterserver.BaseSparkTest;
|
||||||
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
|
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.Tag;
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
import org.nd4j.common.resources.Downloader;
|
||||||
import org.nd4j.common.tests.tags.NativeTag;
|
import org.nd4j.common.tests.tags.NativeTag;
|
||||||
import org.nd4j.common.tests.tags.TagNames;
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.net.URI;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.nio.file.StandardCopyOption;
|
import java.nio.file.StandardCopyOption;
|
||||||
|
|
||||||
|
@ -43,14 +48,32 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
@Tag(TagNames.SPARK)
|
@Tag(TagNames.SPARK)
|
||||||
@Tag(TagNames.DIST_SYSTEMS)
|
@Tag(TagNames.DIST_SYSTEMS)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Slf4j
|
||||||
public class TestElephasImport extends BaseSparkTest {
|
public class TestElephasImport extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
@SneakyThrows
|
||||||
|
public static void beforeAll() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp");
|
||||||
|
File binDir = new File(hadoopHome,"bin");
|
||||||
|
if(!binDir.exists())
|
||||||
|
binDir.mkdirs();
|
||||||
|
File outputFile = new File(binDir,"winutils.exe");
|
||||||
|
if(!outputFile.exists()) {
|
||||||
|
log.info("Fixing spark for windows");
|
||||||
|
Downloader.download("winutils.exe",
|
||||||
|
URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(),
|
||||||
|
outputFile,"db24b404d2331a1bec7443336a5171f1",3);
|
||||||
|
}
|
||||||
|
|
||||||
|
System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testElephasSequentialImport() throws Exception {
|
public void testElephasSequentialImport() throws Exception {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
String modelPath = "modelimport/elephas/elephas_sequential.h5";
|
String modelPath = "modelimport/elephas/elephas_sequential.h5";
|
||||||
SparkDl4jMultiLayer model = importElephasSequential(sc, modelPath);
|
SparkDl4jMultiLayer model = importElephasSequential(sc, modelPath);
|
||||||
// System.out.println(model.getNetwork().summary());
|
// System.out.println(model.getNetwork().summary());
|
||||||
|
|
|
@ -44,7 +44,6 @@ import org.deeplearning4j.spark.api.RDDTrainingApproach;
|
||||||
import org.deeplearning4j.spark.api.TrainingMaster;
|
import org.deeplearning4j.spark.api.TrainingMaster;
|
||||||
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
|
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
|
||||||
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
|
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
|
||||||
import org.deeplearning4j.spark.parameterserver.BaseSparkTest;
|
|
||||||
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
|
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
|
||||||
|
@ -75,6 +74,7 @@ import java.util.*;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
import org.deeplearning4j.spark.parameterserver.BaseSparkTest;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
//@Disabled("AB 2019/05/21 - Failing - Issue #7657")
|
//@Disabled("AB 2019/05/21 - Failing - Issue #7657")
|
||||||
|
@ -82,6 +82,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
@Tag(TagNames.SPARK)
|
@Tag(TagNames.SPARK)
|
||||||
@Tag(TagNames.DIST_SYSTEMS)
|
@Tag(TagNames.DIST_SYSTEMS)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public class GradientSharingTrainingTest extends BaseSparkTest {
|
public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
||||||
|
@ -339,11 +341,12 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test @Disabled
|
@Test
|
||||||
public void testEpochUpdating(@TempDir Path testDir) throws Exception {
|
public void testEpochUpdating(@TempDir Path testDir) throws Exception {
|
||||||
//Ensure that epoch counter is incremented properly on the workers
|
//Ensure that epoch counter is incremented properly on the workers
|
||||||
|
|
||||||
File temp = testDir.toFile();
|
File temp = testDir.resolve("new-dir-" + UUID.randomUUID().toString()).toFile();
|
||||||
|
temp.mkdirs();
|
||||||
|
|
||||||
//TODO this probably won't work everywhere...
|
//TODO this probably won't work everywhere...
|
||||||
String controller = Inet4Address.getLocalHost().getHostAddress();
|
String controller = Inet4Address.getLocalHost().getHostAddress();
|
||||||
|
@ -394,7 +397,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
JavaRDD<String> pathRdd = sc.parallelize(paths);
|
JavaRDD<String> pathRdd = sc.parallelize(paths);
|
||||||
for( int i=0; i<3; i++ ) {
|
for( int i = 0; i < 3; i++) {
|
||||||
ThresholdAlgorithm ta = tm.getThresholdAlgorithm();
|
ThresholdAlgorithm ta = tm.getThresholdAlgorithm();
|
||||||
sparkNet.fitPaths(pathRdd);
|
sparkNet.fitPaths(pathRdd);
|
||||||
//Check also that threshold algorithm was updated/averaged
|
//Check also that threshold algorithm was updated/averaged
|
||||||
|
|
|
@ -1,123 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.deeplearning4j.spark.iterator;
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.spark.TaskContext;
|
|
||||||
import org.apache.spark.TaskContextHelper;
|
|
||||||
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
|
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
||||||
import org.nd4j.linalg.dataset.callbacks.DataSetCallback;
|
|
||||||
import org.nd4j.linalg.dataset.callbacks.DefaultCallback;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import java.util.concurrent.BlockingQueue;
|
|
||||||
import java.util.concurrent.LinkedBlockingQueue;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class SparkADSI extends AsyncDataSetIterator {
|
|
||||||
protected TaskContext context;
|
|
||||||
|
|
||||||
protected SparkADSI() {
|
|
||||||
super();
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkADSI(DataSetIterator baseIterator) {
|
|
||||||
this(baseIterator, 8);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue) {
|
|
||||||
this(iterator, queueSize, queue, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkADSI(DataSetIterator baseIterator, int queueSize) {
|
|
||||||
this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize));
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkADSI(DataSetIterator baseIterator, int queueSize, boolean useWorkspace) {
|
|
||||||
this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkADSI(DataSetIterator baseIterator, int queueSize, boolean useWorkspace, Integer deviceId) {
|
|
||||||
this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace, new DefaultCallback(),
|
|
||||||
deviceId);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkADSI(DataSetIterator baseIterator, int queueSize, boolean useWorkspace, DataSetCallback callback) {
|
|
||||||
this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace, callback);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace) {
|
|
||||||
this(iterator, queueSize, queue, useWorkspace, new DefaultCallback());
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace,
|
|
||||||
DataSetCallback callback) {
|
|
||||||
this(iterator, queueSize, queue, useWorkspace, callback, Nd4j.getAffinityManager().getDeviceForCurrentThread());
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace,
|
|
||||||
DataSetCallback callback, Integer deviceId) {
|
|
||||||
this();
|
|
||||||
|
|
||||||
if (queueSize < 2)
|
|
||||||
queueSize = 2;
|
|
||||||
|
|
||||||
this.deviceId = deviceId;
|
|
||||||
this.callback = callback;
|
|
||||||
this.useWorkspace = useWorkspace;
|
|
||||||
this.buffer = queue;
|
|
||||||
this.prefetchSize = queueSize;
|
|
||||||
this.backedIterator = iterator;
|
|
||||||
this.workspaceId = "SADSI_ITER-" + java.util.UUID.randomUUID().toString();
|
|
||||||
|
|
||||||
if (iterator.resetSupported())
|
|
||||||
this.backedIterator.reset();
|
|
||||||
|
|
||||||
context = TaskContext.get();
|
|
||||||
|
|
||||||
this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null, Nd4j.getAffinityManager().getDeviceForCurrentThread());
|
|
||||||
|
|
||||||
/**
|
|
||||||
* We want to ensure, that background thread will have the same thread->device affinity, as master thread
|
|
||||||
*/
|
|
||||||
|
|
||||||
thread.setDaemon(true);
|
|
||||||
thread.start();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void externalCall() {
|
|
||||||
TaskContextHelper.setTaskContext(context);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public class SparkPrefetchThread extends AsyncPrefetchThread {
|
|
||||||
|
|
||||||
protected SparkPrefetchThread(BlockingQueue<DataSet> queue, DataSetIterator iterator, DataSet terminator, MemoryWorkspace workspace, int deviceId) {
|
|
||||||
super(queue, iterator, terminator, workspace, deviceId);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,118 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.deeplearning4j.spark.iterator;
|
|
||||||
|
|
||||||
import lombok.NonNull;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.spark.TaskContext;
|
|
||||||
import org.apache.spark.TaskContextHelper;
|
|
||||||
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
|
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
|
||||||
import org.nd4j.linalg.dataset.callbacks.DataSetCallback;
|
|
||||||
import org.nd4j.linalg.dataset.callbacks.DefaultCallback;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import java.util.concurrent.BlockingQueue;
|
|
||||||
import java.util.concurrent.LinkedBlockingQueue;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class SparkAMDSI extends AsyncMultiDataSetIterator {
|
|
||||||
protected TaskContext context;
|
|
||||||
|
|
||||||
protected SparkAMDSI() {
|
|
||||||
super();
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkAMDSI(MultiDataSetIterator baseIterator) {
|
|
||||||
this(baseIterator, 8);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue) {
|
|
||||||
this(iterator, queueSize, queue, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize) {
|
|
||||||
this(baseIterator, queueSize, new LinkedBlockingQueue<MultiDataSet>(queueSize));
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace) {
|
|
||||||
this(baseIterator, queueSize, new LinkedBlockingQueue<MultiDataSet>(queueSize), useWorkspace);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace, Integer deviceId) {
|
|
||||||
this(baseIterator, queueSize, new LinkedBlockingQueue<MultiDataSet>(queueSize), useWorkspace,
|
|
||||||
new DefaultCallback(), deviceId);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace,
|
|
||||||
DataSetCallback callback) {
|
|
||||||
this(baseIterator, queueSize, new LinkedBlockingQueue<MultiDataSet>(queueSize), useWorkspace, callback);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue,
|
|
||||||
boolean useWorkspace) {
|
|
||||||
this(iterator, queueSize, queue, useWorkspace, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue,
|
|
||||||
boolean useWorkspace, DataSetCallback callback) {
|
|
||||||
this(iterator, queueSize, queue, useWorkspace, callback, Nd4j.getAffinityManager().getDeviceForCurrentThread());
|
|
||||||
}
|
|
||||||
|
|
||||||
public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue,
|
|
||||||
boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
|
|
||||||
this();
|
|
||||||
|
|
||||||
if (queueSize < 2)
|
|
||||||
queueSize = 2;
|
|
||||||
|
|
||||||
this.callback = callback;
|
|
||||||
this.buffer = queue;
|
|
||||||
this.backedIterator = iterator;
|
|
||||||
this.useWorkspaces = useWorkspace;
|
|
||||||
this.prefetchSize = queueSize;
|
|
||||||
this.workspaceId = "SAMDSI_ITER-" + java.util.UUID.randomUUID().toString();
|
|
||||||
this.deviceId = deviceId;
|
|
||||||
|
|
||||||
if (iterator.resetSupported())
|
|
||||||
this.backedIterator.reset();
|
|
||||||
|
|
||||||
this.thread = new SparkPrefetchThread(buffer, iterator, terminator, Nd4j.getAffinityManager().getDeviceForCurrentThread());
|
|
||||||
|
|
||||||
context = TaskContext.get();
|
|
||||||
|
|
||||||
thread.setDaemon(true);
|
|
||||||
thread.start();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void externalCall() {
|
|
||||||
TaskContextHelper.setTaskContext(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected class SparkPrefetchThread extends AsyncPrefetchThread {
|
|
||||||
|
|
||||||
protected SparkPrefetchThread(@NonNull BlockingQueue<MultiDataSet> queue, @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator, int deviceId) {
|
|
||||||
super(queue, iterator, terminator, deviceId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -64,10 +64,6 @@ public class TestEarlyStoppingSpark extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEarlyStoppingIris() {
|
public void testEarlyStoppingIris() {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.updater(new Sgd()).weightInit(WeightInit.XAVIER).list()
|
.updater(new Sgd()).weightInit(WeightInit.XAVIER).list()
|
||||||
|
|
|
@ -67,10 +67,6 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEarlyStoppingIris() {
|
public void testEarlyStoppingIris() {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in")
|
.updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in")
|
||||||
|
|
|
@ -76,10 +76,6 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDataVecDataSetFunction(@TempDir Path testDir) throws Exception {
|
public void testDataVecDataSetFunction(@TempDir Path testDir) throws Exception {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
|
|
||||||
File f = testDir.toFile();
|
File f = testDir.toFile();
|
||||||
|
|
|
@ -51,10 +51,6 @@ public class TestExport extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBatchAndExportDataSetsFunction() throws Exception {
|
public void testBatchAndExportDataSetsFunction() throws Exception {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
String baseDir = System.getProperty("java.io.tmpdir");
|
String baseDir = System.getProperty("java.io.tmpdir");
|
||||||
baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExport/");
|
baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExport/");
|
||||||
baseDir = baseDir.replaceAll("\\\\", "/");
|
baseDir = baseDir.replaceAll("\\\\", "/");
|
||||||
|
|
|
@ -70,10 +70,6 @@ public class TestPreProcessedData extends BaseSparkTest {
|
||||||
@Test
|
@Test
|
||||||
public void testPreprocessedData() {
|
public void testPreprocessedData() {
|
||||||
//Test _loading_ of preprocessed data
|
//Test _loading_ of preprocessed data
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
int dataSetObjSize = 5;
|
int dataSetObjSize = 5;
|
||||||
int batchSizePerExecutor = 10;
|
int batchSizePerExecutor = 10;
|
||||||
|
|
||||||
|
|
|
@ -52,10 +52,6 @@ public class TestCustomLayer extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSparkWithCustomLayer() {
|
public void testSparkWithCustomLayer() {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
//Basic test - checks whether exceptions etc are thrown with custom layers + spark
|
//Basic test - checks whether exceptions etc are thrown with custom layers + spark
|
||||||
//Custom layers are tested more extensively in dl4j core
|
//Custom layers are tested more extensively in dl4j core
|
||||||
MultiLayerConfiguration conf =
|
MultiLayerConfiguration conf =
|
||||||
|
|
|
@ -77,10 +77,6 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEvaluationSimple() throws Exception {
|
public void testEvaluationSimple() throws Exception {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
for( int evalWorkers : new int[]{1, 4, 8}) {
|
for( int evalWorkers : new int[]{1, 4, 8}) {
|
||||||
|
|
|
@ -61,10 +61,6 @@ public class TestTrainingStatsCollection extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testStatsCollection() throws Exception {
|
public void testStatsCollection() throws Exception {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
int nWorkers = numExecutors();
|
int nWorkers = numExecutors();
|
||||||
|
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
|
|
|
@ -60,10 +60,6 @@ public class TestListeners extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testStatsCollection() {
|
public void testStatsCollection() {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
JavaSparkContext sc = getContext();
|
JavaSparkContext sc = getContext();
|
||||||
int nExecutors = numExecutors();
|
int nExecutors = numExecutors();
|
||||||
|
|
||||||
|
|
|
@ -54,10 +54,6 @@ public class TestRepartitioning extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRepartitioning() {
|
public void testRepartitioning() {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
List<String> list = new ArrayList<>();
|
List<String> list = new ArrayList<>();
|
||||||
for (int i = 0; i < 1000; i++) {
|
for (int i = 0; i < 1000; i++) {
|
||||||
list.add(String.valueOf(i));
|
list.add(String.valueOf(i));
|
||||||
|
|
|
@ -52,10 +52,6 @@ public class TestValidation extends BaseSparkTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDataSetValidation(@TempDir Path folder) throws Exception {
|
public void testDataSetValidation(@TempDir Path folder) throws Exception {
|
||||||
if(Platform.isWindows()) {
|
|
||||||
//Spark tests don't run on windows
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
File f = folder.toFile();
|
File f = folder.toFile();
|
||||||
|
|
||||||
for( int i = 0; i < 3; i++ ) {
|
for( int i = 0; i < 3; i++ ) {
|
||||||
|
|
|
@ -38,11 +38,17 @@
|
||||||
<module>deeplearning4j-ui</module>
|
<module>deeplearning4j-ui</module>
|
||||||
<module>deeplearning4j-ui-components</module>
|
<module>deeplearning4j-ui-components</module>
|
||||||
<module>deeplearning4j-ui-model</module>
|
<module>deeplearning4j-ui-model</module>
|
||||||
<module>deeplearning4j-ui-standalone</module>
|
|
||||||
<module>deeplearning4j-vertx</module>
|
<module>deeplearning4j-vertx</module>
|
||||||
</modules>
|
</modules>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
<profile>
|
||||||
|
<id>ui-jar</id>
|
||||||
|
<modules>
|
||||||
|
<module>deeplearning4j-ui-standalone</module>
|
||||||
|
</modules>
|
||||||
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>nd4j-tests-cpu</id>
|
<id>nd4j-tests-cpu</id>
|
||||||
</profile>
|
</profile>
|
||||||
|
|
|
@ -41,6 +41,7 @@ import java.io.File;
|
||||||
@Tag(TagNames.DL4J_OLD_API)
|
@Tag(TagNames.DL4J_OLD_API)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
@Tag(TagNames.LONG_TEST)
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public class MiscTests extends BaseDL4JTest {
|
public class MiscTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -51,6 +51,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
@Tag(TagNames.DL4J_OLD_API)
|
@Tag(TagNames.DL4J_OLD_API)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
@Tag(TagNames.LONG_TEST)
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public class TestDownload extends BaseDL4JTest {
|
public class TestDownload extends BaseDL4JTest {
|
||||||
@TempDir
|
@TempDir
|
||||||
static Path sharedTempDir;
|
static Path sharedTempDir;
|
||||||
|
|
|
@ -61,6 +61,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
@Tag(TagNames.DL4J_OLD_API)
|
@Tag(TagNames.DL4J_OLD_API)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
@Tag(TagNames.LONG_TEST)
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public class TestImageNet extends BaseDL4JTest {
|
public class TestImageNet extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -2308,7 +2308,7 @@ public class Nd4j {
|
||||||
data2.add(readSplit(data));
|
data2.add(readSplit(data));
|
||||||
}
|
}
|
||||||
float[][] fArr = new float[data2.size()][0];
|
float[][] fArr = new float[data2.size()][0];
|
||||||
for(int i=0; i<data2.size(); i++ ){
|
for(int i = 0; i < data2.size(); i++) {
|
||||||
fArr[i] = data2.get(i);
|
fArr[i] = data2.get(i);
|
||||||
}
|
}
|
||||||
ret = Nd4j.createFromArray(fArr).castTo(dataType);
|
ret = Nd4j.createFromArray(fArr).castTo(dataType);
|
||||||
|
@ -2785,7 +2785,7 @@ public class Nd4j {
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(@NonNull int... shape) {
|
public static INDArray rand(@NonNull int... shape) {
|
||||||
INDArray ret = createUninitialized(shape, order()).castTo(Nd4j.defaultFloatingPointType()); //INSTANCE.rand(shape, Nd4j.getRandom());
|
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom());
|
||||||
return rand(ret);
|
return rand(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2793,7 +2793,7 @@ public class Nd4j {
|
||||||
* See {@link #rand(int[])}
|
* See {@link #rand(int[])}
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(@NonNull long... shape) {
|
public static INDArray rand(@NonNull long... shape) {
|
||||||
INDArray ret = createUninitialized(shape, order()).castTo(Nd4j.defaultFloatingPointType()); //INSTANCE.rand(shape, Nd4j.getRandom());
|
INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom());
|
||||||
return rand(ret);
|
return rand(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2806,7 +2806,7 @@ public class Nd4j {
|
||||||
public static INDArray rand(@NonNull DataType dataType, @NonNull long... shape) {
|
public static INDArray rand(@NonNull DataType dataType, @NonNull long... shape) {
|
||||||
Preconditions.checkArgument(dataType.isFPType(),
|
Preconditions.checkArgument(dataType.isFPType(),
|
||||||
"Can't create a random array of a non-floating point data type");
|
"Can't create a random array of a non-floating point data type");
|
||||||
INDArray ret = createUninitialized(dataType, shape, order()).castTo(Nd4j.defaultFloatingPointType()); //INSTANCE.rand(shape, Nd4j.getRandom());
|
INDArray ret = createUninitialized(dataType, shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom());
|
||||||
return rand(ret);
|
return rand(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2820,7 +2820,7 @@ public class Nd4j {
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(char order, @NonNull int... shape) {
|
public static INDArray rand(char order, @NonNull int... shape) {
|
||||||
INDArray ret = Nd4j.createUninitialized(shape, order).castTo(Nd4j.defaultFloatingPointType()); //INSTANCE.rand(order, shape);
|
INDArray ret = Nd4j.createUninitialized(shape, order); //INSTANCE.rand(order, shape);
|
||||||
return rand(ret);
|
return rand(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2829,7 +2829,7 @@ public class Nd4j {
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public static INDArray rand(@NonNull DataType dataType, int[] shape, char order) {
|
public static INDArray rand(@NonNull DataType dataType, int[] shape, char order) {
|
||||||
return rand(dataType, order, ArrayUtil.toLongArray(shape)).castTo(Nd4j.defaultFloatingPointType());
|
return rand(dataType, order, ArrayUtil.toLongArray(shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -2837,7 +2837,7 @@ public class Nd4j {
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public static INDArray rand(@NonNull DataType dataType, char order, @NonNull int... shape) {
|
public static INDArray rand(@NonNull DataType dataType, char order, @NonNull int... shape) {
|
||||||
return rand(dataType, order, ArrayUtil.toLongArray(shape)).castTo(Nd4j.defaultFloatingPointType());
|
return rand(dataType, order, ArrayUtil.toLongArray(shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -2851,7 +2851,7 @@ public class Nd4j {
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(@NonNull DataType dataType, char order, @NonNull long... shape) {
|
public static INDArray rand(@NonNull DataType dataType, char order, @NonNull long... shape) {
|
||||||
INDArray ret = Nd4j.createUninitialized(dataType, shape, order).castTo(Nd4j.defaultFloatingPointType());
|
INDArray ret = Nd4j.createUninitialized(dataType, shape, order);
|
||||||
return rand(ret);
|
return rand(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2866,7 +2866,7 @@ public class Nd4j {
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(@NonNull DataType dataType, @NonNull int... shape) {
|
public static INDArray rand(@NonNull DataType dataType, @NonNull int... shape) {
|
||||||
INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order()).castTo(Nd4j.defaultFloatingPointType());
|
INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order());
|
||||||
return rand(ret);
|
return rand(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2911,7 +2911,7 @@ public class Nd4j {
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(long seed, @NonNull long... shape) {
|
public static INDArray rand(long seed, @NonNull long... shape) {
|
||||||
INDArray ret = createUninitialized(shape, Nd4j.order()).castTo(Nd4j.defaultFloatingPointType());//;INSTANCE.rand(shape, seed);
|
INDArray ret = createUninitialized(shape, Nd4j.order());//;INSTANCE.rand(shape, seed);
|
||||||
return rand(ret, seed);
|
return rand(ret, seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2920,7 +2920,7 @@ public class Nd4j {
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public static INDArray rand(int[] shape, long seed) {
|
public static INDArray rand(int[] shape, long seed) {
|
||||||
return rand(seed, ArrayUtil.toLongArray(shape)).castTo(Nd4j.defaultFloatingPointType());
|
return rand(seed, ArrayUtil.toLongArray(shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2943,7 +2943,7 @@ public class Nd4j {
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public static INDArray rand(int[] shape, @NonNull org.nd4j.linalg.api.rng.Random rng) {
|
public static INDArray rand(int[] shape, @NonNull org.nd4j.linalg.api.rng.Random rng) {
|
||||||
return rand(rng, ArrayUtil.toLongArray(shape)).castTo(Nd4j.defaultFloatingPointType());
|
return rand(rng, ArrayUtil.toLongArray(shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -2954,7 +2954,7 @@ public class Nd4j {
|
||||||
* @return the random ndarray with the specified shape
|
* @return the random ndarray with the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray rand(@NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) {
|
public static INDArray rand(@NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) {
|
||||||
INDArray ret = createUninitialized(shape, Nd4j.order()).castTo(Nd4j.defaultFloatingPointType()); //INSTANCE.rand(shape, rng);
|
INDArray ret = createUninitialized(shape, Nd4j.order()); //INSTANCE.rand(shape, rng);
|
||||||
return rand(ret, rng);
|
return rand(ret, rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2963,7 +2963,7 @@ public class Nd4j {
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public static INDArray rand(int[] shape, @NonNull Distribution dist) {
|
public static INDArray rand(int[] shape, @NonNull Distribution dist) {
|
||||||
return rand(dist, ArrayUtil.toLongArray(shape)).castTo(Nd4j.defaultFloatingPointType());
|
return rand(dist, ArrayUtil.toLongArray(shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -75,7 +75,8 @@ public class BinarySerde {
|
||||||
ByteBuffer byteBuffer = buffer.hasArray() ? ByteBuffer.allocateDirect(buffer.array().length).put(buffer.array())
|
ByteBuffer byteBuffer = buffer.hasArray() ? ByteBuffer.allocateDirect(buffer.array().length).put(buffer.array())
|
||||||
.order(ByteOrder.nativeOrder()) : buffer.order(ByteOrder.nativeOrder());
|
.order(ByteOrder.nativeOrder()) : buffer.order(ByteOrder.nativeOrder());
|
||||||
//bump the byte buffer to the proper position
|
//bump the byte buffer to the proper position
|
||||||
byteBuffer.position(offset);
|
Buffer buffer1 = (Buffer) byteBuffer;
|
||||||
|
buffer1.position(offset);
|
||||||
int rank = byteBuffer.getInt();
|
int rank = byteBuffer.getInt();
|
||||||
if (rank < 0)
|
if (rank < 0)
|
||||||
throw new IllegalStateException("Found negative integer. Corrupt serialization?");
|
throw new IllegalStateException("Found negative integer. Corrupt serialization?");
|
||||||
|
@ -99,7 +100,8 @@ public class BinarySerde {
|
||||||
DataBuffer buff = Nd4j.createBuffer(slice, type, (int) Shape.length(shapeBuff));
|
DataBuffer buff = Nd4j.createBuffer(slice, type, (int) Shape.length(shapeBuff));
|
||||||
//advance past the data
|
//advance past the data
|
||||||
int position = byteBuffer.position() + (buff.getElementSize() * (int) buff.length());
|
int position = byteBuffer.position() + (buff.getElementSize() * (int) buff.length());
|
||||||
byteBuffer.position(position);
|
Buffer buffer2 = (Buffer) byteBuffer;
|
||||||
|
buffer2.position(position);
|
||||||
//create the final array
|
//create the final array
|
||||||
//TODO: see how to avoid dup here
|
//TODO: see how to avoid dup here
|
||||||
INDArray arr = Nd4j.createArrayFromShapeBuffer(buff.dup(), shapeBuff.dup());
|
INDArray arr = Nd4j.createArrayFromShapeBuffer(buff.dup(), shapeBuff.dup());
|
||||||
|
@ -116,7 +118,8 @@ public class BinarySerde {
|
||||||
INDArray arr = Nd4j.createArrayFromShapeBuffer(compressedDataBuffer.dup(), shapeBuff.dup());
|
INDArray arr = Nd4j.createArrayFromShapeBuffer(compressedDataBuffer.dup(), shapeBuff.dup());
|
||||||
//advance past the data
|
//advance past the data
|
||||||
int compressLength = (int) compressionDescriptor.getCompressedLength();
|
int compressLength = (int) compressionDescriptor.getCompressedLength();
|
||||||
byteBuffer.position(byteBuffer.position() + compressLength);
|
Buffer buffer2 = (Buffer) byteBuffer;
|
||||||
|
buffer2.position(buffer2.position() + compressLength);
|
||||||
return Pair.of(arr, byteBuffer);
|
return Pair.of(arr, byteBuffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -140,6 +140,7 @@
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<configuration>
|
<configuration>
|
||||||
|
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
|
||||||
<forkCount>${cpu.core.count}</forkCount>
|
<forkCount>${cpu.core.count}</forkCount>
|
||||||
<reuseForks>false</reuseForks>
|
<reuseForks>false</reuseForks>
|
||||||
<environmentVariables>
|
<environmentVariables>
|
||||||
|
@ -162,7 +163,7 @@
|
||||||
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
||||||
Depending on a build machine, default value is not always enough.
|
Depending on a build machine, default value is not always enough.
|
||||||
-->
|
-->
|
||||||
<argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g</argLine>
|
<argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine>
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
<plugin>
|
<plugin>
|
||||||
|
|
|
@ -526,7 +526,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
public INDArray toFlattened(char order, Collection<INDArray> matrices) {
|
public INDArray toFlattened(char order, Collection<INDArray> matrices) {
|
||||||
Preconditions.checkArgument(matrices.size() > 0, "toFlattened expects > 0 operands");
|
Preconditions.checkArgument(matrices.size() > 0, "toFlattened expects > 0 operands");
|
||||||
|
|
||||||
return Nd4j.exec(new Flatten(order, matrices.toArray(new INDArray[matrices.size()])))[0];
|
return Nd4j.exec(new Flatten(order, matrices.toArray(new INDArray[matrices.size()])))[0]
|
||||||
|
.castTo(matrices.iterator().next().dataType());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -124,6 +124,7 @@
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<configuration>
|
<configuration>
|
||||||
|
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
|
||||||
<forkCount>${cpu.core.count}</forkCount>
|
<forkCount>${cpu.core.count}</forkCount>
|
||||||
<reuseForks>false</reuseForks>
|
<reuseForks>false</reuseForks>
|
||||||
<environmentVariables>
|
<environmentVariables>
|
||||||
|
@ -139,7 +140,12 @@
|
||||||
Maximum heap size was set to 8g, as a minimum required value for tests run.
|
Maximum heap size was set to 8g, as a minimum required value for tests run.
|
||||||
Depending on a build machine, default value is not always enough.
|
Depending on a build machine, default value is not always enough.
|
||||||
-->
|
-->
|
||||||
<argLine>-Xmx2g</argLine>
|
<argLine>-Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine>
|
||||||
|
<forkedProcessTimeoutInSeconds>240</forkedProcessTimeoutInSeconds>
|
||||||
|
<forkedProcessExitTimeoutInSeconds>240</forkedProcessExitTimeoutInSeconds>
|
||||||
|
<parallelTestsTimeoutInSeconds>240</parallelTestsTimeoutInSeconds>
|
||||||
|
<parallelTestsTimeoutForcedInSeconds>240</parallelTestsTimeoutForcedInSeconds>
|
||||||
|
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
<plugin>
|
<plugin>
|
||||||
|
|
|
@ -269,6 +269,7 @@
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<configuration>
|
<configuration>
|
||||||
|
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
|
||||||
<forkCount>${cpu.core.count}</forkCount>
|
<forkCount>${cpu.core.count}</forkCount>
|
||||||
<reuseForks>false</reuseForks>
|
<reuseForks>false</reuseForks>
|
||||||
<environmentVariables>
|
<environmentVariables>
|
||||||
|
@ -304,7 +305,7 @@
|
||||||
|
|
||||||
For testing large zoo models, this may not be enough (so comment it out).
|
For testing large zoo models, this may not be enough (so comment it out).
|
||||||
-->
|
-->
|
||||||
<argLine>-Dfile.encoding=UTF-8 </argLine>
|
<argLine>-Dfile.encoding=UTF-8 -Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine>
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
|
@ -350,6 +351,7 @@
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
<configuration>
|
<configuration>
|
||||||
|
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
|
||||||
<environmentVariables>
|
<environmentVariables>
|
||||||
<LD_LIBRARY_PATH>
|
<LD_LIBRARY_PATH>
|
||||||
${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes
|
${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes
|
||||||
|
@ -379,7 +381,12 @@
|
||||||
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
||||||
Depending on a build machine, default value is not always enough.
|
Depending on a build machine, default value is not always enough.
|
||||||
-->
|
-->
|
||||||
<argLine> -Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
<argLine>-Dfile.encoding=UTF-8 -Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine>
|
||||||
|
<forkedProcessTimeoutInSeconds>240</forkedProcessTimeoutInSeconds>
|
||||||
|
<forkedProcessExitTimeoutInSeconds>240</forkedProcessExitTimeoutInSeconds>
|
||||||
|
<parallelTestsTimeoutInSeconds>240</parallelTestsTimeoutInSeconds>
|
||||||
|
<parallelTestsTimeoutForcedInSeconds>240</parallelTestsTimeoutForcedInSeconds>
|
||||||
|
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
|
|
|
@ -216,6 +216,8 @@ public class TestSessions extends BaseNd4jTestWithBackends {
|
||||||
@Tag(TagNames.FILE_IO)
|
@Tag(TagNames.FILE_IO)
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public void testSwitchWhile(Nd4jBackend backend) throws Exception{
|
public void testSwitchWhile(Nd4jBackend backend) throws Exception{
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -94,7 +94,6 @@ import org.nd4j.weightinit.impl.UniformInitScheme;
|
||||||
@Tag(TagNames.SAMEDIFF)
|
@Tag(TagNames.SAMEDIFF)
|
||||||
public class SameDiffTests extends BaseNd4jTestWithBackends {
|
public class SameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
private DataType initialType;
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -112,16 +111,11 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void before() {
|
public void before() {
|
||||||
Nd4j.create(1);
|
Nd4j.create(1);
|
||||||
initialType = Nd4j.dataType();
|
|
||||||
|
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
|
||||||
Nd4j.getRandom().setSeed(123);
|
Nd4j.getRandom().setSeed(123);
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
public void after() {
|
public void after() {
|
||||||
Nd4j.setDataType(initialType);
|
|
||||||
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
|
NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
|
||||||
}
|
}
|
||||||
|
@ -136,7 +130,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
INDArray labels = Nd4j.create(new double[]{1, 1, 0, 1}).reshape(4, 1);
|
INDArray labels = Nd4j.create(new double[]{1, 1, 0, 1}).reshape(4, 1);
|
||||||
|
|
||||||
INDArray weights = Nd4j.zeros(3, 1);
|
INDArray weights = Nd4j.zeros(3, 1).castTo(labels.dataType());
|
||||||
|
|
||||||
Map<String, INDArray> inputMap = new HashMap<>();
|
Map<String, INDArray> inputMap = new HashMap<>();
|
||||||
inputMap.put("x", inputs);
|
inputMap.put("x", inputs);
|
||||||
|
@ -155,7 +149,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
val nodeA = sd.math().square(input);
|
val nodeA = sd.math().square(input);
|
||||||
val nodeB = sd.math().square(nodeA);
|
val nodeB = sd.math().square(nodeA);
|
||||||
|
|
||||||
sd.associateArrayWithVariable(Nd4j.create(new double[]{1, 2, 3, 4, 5, 6}, new long[]{2, 3}), input);
|
sd.associateArrayWithVariable(Nd4j.create(new double[]{1, 2, 3, 4, 5, 6}, new long[]{2, 3}).castTo(input.dataType()), input);
|
||||||
|
|
||||||
sd.outputAll(null);
|
sd.outputAll(null);
|
||||||
|
|
||||||
|
@ -2627,7 +2621,8 @@ public class SameDiffTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 3);
|
SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 3);
|
||||||
SDVariable w = sd.constant("w", Nd4j.rand(DataType.FLOAT, 3, 4));
|
INDArray const1 = Nd4j.rand(DataType.FLOAT, 3, 4);
|
||||||
|
SDVariable w = sd.constant("w",const1);
|
||||||
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 4));
|
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 4));
|
||||||
SDVariable mmul = in.mmul(w);
|
SDVariable mmul = in.mmul(w);
|
||||||
SDVariable add = mmul.add(b);
|
SDVariable add = mmul.add(b);
|
||||||
|
|
|
@ -21,13 +21,18 @@
|
||||||
package org.nd4j.autodiff.samediff.listeners;
|
package org.nd4j.autodiff.samediff.listeners;
|
||||||
|
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener;
|
import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.autodiff.samediff.TrainingConfig;
|
import org.nd4j.autodiff.samediff.TrainingConfig;
|
||||||
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.dataset.IrisDataSetIterator;
|
import org.nd4j.linalg.dataset.IrisDataSetIterator;
|
||||||
|
@ -38,10 +43,7 @@ import org.nd4j.linalg.learning.config.Adam;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.Arrays;
|
import java.util.*;
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
@ -169,8 +171,12 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
|
@Disabled("Inconsistent results on output")
|
||||||
|
@Tag(TagNames.NEEDS_VERIFY)
|
||||||
public void testCheckpointListenerEveryTimeUnit(Nd4jBackend backend) throws Exception {
|
public void testCheckpointListenerEveryTimeUnit(Nd4jBackend backend) throws Exception {
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.resolve("new-dir-" + UUID.randomUUID().toString()).toFile();
|
||||||
|
assertTrue(dir.mkdirs());
|
||||||
SameDiff sd = getModel();
|
SameDiff sd = getModel();
|
||||||
|
|
||||||
CheckpointListener l = new CheckpointListener.Builder(dir)
|
CheckpointListener l = new CheckpointListener.Builder(dir)
|
||||||
|
@ -181,9 +187,8 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
DataSetIterator iter = getIter(15, 150);
|
DataSetIterator iter = getIter(15, 150);
|
||||||
|
|
||||||
for(int i=0; i<5; i++ ){ //10 iterations total
|
for(int i = 0; i < 5; i++) { //10 iterations total
|
||||||
sd.fit(iter, 1);
|
sd.fit(iter, 1);
|
||||||
Thread.sleep(5000);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//Expect models saved at iterations: 10, 20, 30, 40
|
//Expect models saved at iterations: 10, 20, 30, 40
|
||||||
|
|
|
@ -123,7 +123,7 @@ public class ListenerTest extends BaseNd4jTestWithBackends {
|
||||||
//
|
//
|
||||||
// sd.evaluateMultiple(iter, evalMap);
|
// sd.evaluateMultiple(iter, evalMap);
|
||||||
|
|
||||||
e = (Evaluation) hist.finalTrainingEvaluations().evaluation(predictions);
|
e = hist.finalTrainingEvaluations().evaluation(predictions);
|
||||||
|
|
||||||
System.out.println(e.stats());
|
System.out.println(e.stats());
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,7 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
RegressionEvaluation eval = new RegressionEvaluation(nCols);
|
RegressionEvaluation eval = new RegressionEvaluation(nCols);
|
||||||
|
|
||||||
for (int i = 0; i < nTestArrays; i++) {
|
for (int i = 0; i < nTestArrays; i++) {
|
||||||
INDArray rand = Nd4j.rand(valuesPerTestArray, nCols).castTo(DataType.DOUBLE);
|
INDArray rand = Nd4j.rand(DataType.DOUBLE,valuesPerTestArray, nCols);
|
||||||
eval.eval(rand, rand);
|
eval.eval(rand, rand);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,8 +172,8 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
for (int i = 0; i < nEvalInstances; i++) {
|
for (int i = 0; i < nEvalInstances; i++) {
|
||||||
list.add(new RegressionEvaluation(nCols));
|
list.add(new RegressionEvaluation(nCols));
|
||||||
for (int j = 0; j < numMinibatches; j++) {
|
for (int j = 0; j < numMinibatches; j++) {
|
||||||
INDArray p = Nd4j.rand(nRows, nCols).castTo(Nd4j.defaultFloatingPointType());
|
INDArray p = Nd4j.rand(DataType.DOUBLE,nRows, nCols);
|
||||||
INDArray act = Nd4j.rand(nRows, nCols).castTo(Nd4j.defaultFloatingPointType());
|
INDArray act = Nd4j.rand(DataType.DOUBLE,nRows, nCols);
|
||||||
|
|
||||||
single.eval(act, p);
|
single.eval(act, p);
|
||||||
|
|
||||||
|
@ -383,7 +383,7 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
List<INDArray> rowsL = new ArrayList<>();
|
List<INDArray> rowsL = new ArrayList<>();
|
||||||
|
|
||||||
//Check per-example masking:
|
//Check per-example masking:
|
||||||
INDArray mask1dPerEx = Nd4j.createFromArray(1, 0);
|
INDArray mask1dPerEx = Nd4j.createFromArray(1, 0).castTo(DataType.FLOAT);
|
||||||
|
|
||||||
NdIndexIterator iter = new NdIndexIterator(2, 10, 10);
|
NdIndexIterator iter = new NdIndexIterator(2, 10, 10);
|
||||||
while (iter.hasNext()) {
|
while (iter.hasNext()) {
|
||||||
|
@ -409,7 +409,7 @@ public class RegressionEvalTest extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
//Check per-output masking:
|
//Check per-output masking:
|
||||||
INDArray perOutMask = Nd4j.randomBernoulli(0.5, label.shape());
|
INDArray perOutMask = Nd4j.randomBernoulli(0.5, label.shape()).castTo(DataType.FLOAT);
|
||||||
rowsP.clear();
|
rowsP.clear();
|
||||||
rowsL.clear();
|
rowsL.clear();
|
||||||
List<INDArray> rowsM = new ArrayList<>();
|
List<INDArray> rowsM = new ArrayList<>();
|
||||||
|
|
|
@ -24,6 +24,8 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
|
@ -45,18 +47,15 @@ public class AveragingTests extends BaseNd4jTestWithBackends {
|
||||||
private final int THREADS = 16;
|
private final int THREADS = 16;
|
||||||
private final int LENGTH = 51200 * 4;
|
private final int LENGTH = 51200 * 4;
|
||||||
|
|
||||||
DataType initialType = Nd4j.dataType();
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
public void shutUp() {
|
public void shutUp() {
|
||||||
DataTypeUtil.setDTypeForContext(initialType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -111,6 +110,7 @@ public class AveragingTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
public void testSingleDeviceAveraging2(Nd4jBackend backend) {
|
public void testSingleDeviceAveraging2(Nd4jBackend backend) {
|
||||||
INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH);
|
INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH);
|
||||||
List<INDArray> arrays = new ArrayList<>();
|
List<INDArray> arrays = new ArrayList<>();
|
||||||
|
|
|
@ -23,11 +23,13 @@ package org.nd4j.linalg;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.commons.lang3.RandomUtils;
|
import org.apache.commons.lang3.RandomUtils;
|
||||||
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
import org.nd4j.common.tests.tags.NativeTag;
|
import org.nd4j.common.tests.tags.NativeTag;
|
||||||
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
||||||
|
@ -239,6 +241,8 @@ public class LoneTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
public void testGetRow1(Nd4jBackend backend) {
|
public void testGetRow1(Nd4jBackend backend) {
|
||||||
INDArray array = Nd4j.create(10000, 10000);
|
INDArray array = Nd4j.create(10000, 10000);
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,9 @@ import org.apache.commons.math3.util.FastMath;
|
||||||
import org.junit.jupiter.api.*;
|
import org.junit.jupiter.api.*;
|
||||||
|
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
|
import org.junit.jupiter.api.parallel.Isolated;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
|
@ -148,8 +151,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
@Tag(TagNames.FILE_IO)
|
@Tag(TagNames.FILE_IO)
|
||||||
public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
DataType initialType = Nd4j.dataType();
|
|
||||||
Level1 l1 = Nd4j.getBlasWrapper().level1();
|
|
||||||
@TempDir Path testDir;
|
@TempDir Path testDir;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -159,7 +160,6 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void before() throws Exception {
|
public void before() throws Exception {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
|
||||||
Nd4j.getRandom().setSeed(123);
|
Nd4j.getRandom().setSeed(123);
|
||||||
Nd4j.getExecutioner().enableDebugMode(false);
|
Nd4j.getExecutioner().enableDebugMode(false);
|
||||||
Nd4j.getExecutioner().enableVerboseMode(false);
|
Nd4j.getExecutioner().enableVerboseMode(false);
|
||||||
|
@ -167,7 +167,6 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
public void after() throws Exception {
|
public void after() throws Exception {
|
||||||
Nd4j.setDataType(initialType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -1480,7 +1479,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}, new int[]{12});
|
INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}, new int[]{12});
|
||||||
INDArray flattened = Nd4j.toFlattened(concat);
|
INDArray flattened = Nd4j.toFlattened(concat).castTo(assertion.dataType());
|
||||||
assertEquals(assertion, flattened);
|
assertEquals(assertion, flattened);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3902,6 +3901,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled("Crashes")
|
||||||
|
@Tag(TagNames.NEEDS_VERIFY)
|
||||||
public void testSingleDeviceAveraging(Nd4jBackend backend) {
|
public void testSingleDeviceAveraging(Nd4jBackend backend) {
|
||||||
int LENGTH = 512 * 1024 * 2;
|
int LENGTH = 512 * 1024 * 2;
|
||||||
INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0);
|
INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0);
|
||||||
|
@ -5587,6 +5588,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled("Crashes")
|
||||||
|
@Tag(TagNames.NEEDS_VERIFY)
|
||||||
public void testNativeSort3(Nd4jBackend backend) {
|
public void testNativeSort3(Nd4jBackend backend) {
|
||||||
int length = isIntegrationTests() ? 1048576 : 16484;
|
int length = isIntegrationTests() ? 1048576 : 16484;
|
||||||
INDArray array = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape(1, -1);
|
INDArray array = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape(1, -1);
|
||||||
|
@ -5719,6 +5722,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled("Crashes")
|
||||||
|
@Tag(TagNames.NEEDS_VERIFY)
|
||||||
public void testNativeSortAlongDimension1(Nd4jBackend backend) {
|
public void testNativeSortAlongDimension1(Nd4jBackend backend) {
|
||||||
INDArray array = Nd4j.create(1000, 1000);
|
INDArray array = Nd4j.create(1000, 1000);
|
||||||
INDArray exp1 = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE);
|
INDArray exp1 = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE);
|
||||||
|
@ -5779,6 +5784,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled("Crashes")
|
||||||
|
@Tag(TagNames.NEEDS_VERIFY)
|
||||||
public void testNativeSortAlongDimension3(Nd4jBackend backend) {
|
public void testNativeSortAlongDimension3(Nd4jBackend backend) {
|
||||||
INDArray array = Nd4j.create(2000, 2000);
|
INDArray array = Nd4j.create(2000, 2000);
|
||||||
INDArray exp1 = Nd4j.linspace(1, 2000, 2000, DataType.DOUBLE);
|
INDArray exp1 = Nd4j.linspace(1, 2000, 2000, DataType.DOUBLE);
|
||||||
|
@ -5814,6 +5821,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled("Crashes")
|
||||||
|
@Tag(TagNames.NEEDS_VERIFY)
|
||||||
public void testNativeSortAlongDimension2(Nd4jBackend backend) {
|
public void testNativeSortAlongDimension2(Nd4jBackend backend) {
|
||||||
INDArray array = Nd4j.create(100, 10);
|
INDArray array = Nd4j.create(100, 10);
|
||||||
INDArray exp1 = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
|
INDArray exp1 = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
|
||||||
|
@ -6768,15 +6777,16 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInconsistentOutput(){
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
|
public void testInconsistentOutput(Nd4jBackend backend) {
|
||||||
INDArray in = Nd4j.rand(1, 802816).castTo(DataType.DOUBLE);
|
INDArray in = Nd4j.rand(1, 802816).castTo(DataType.DOUBLE);
|
||||||
INDArray W = Nd4j.rand(802816, 1).castTo(DataType.DOUBLE);
|
INDArray W = Nd4j.rand(802816, 1).castTo(DataType.DOUBLE);
|
||||||
INDArray b = Nd4j.create(1).castTo(DataType.DOUBLE);
|
INDArray b = Nd4j.create(1).castTo(DataType.DOUBLE);
|
||||||
INDArray out = fwd(in, W, b);
|
INDArray out = fwd(in, W, b);
|
||||||
|
|
||||||
for(int i = 0;i < 100;i++) {
|
for(int i = 0; i < 100;i++) {
|
||||||
INDArray out2 = fwd(in, W, b); //l.activate(inToLayer1, false, LayerWorkspaceMgr.noWorkspaces());
|
INDArray out2 = fwd(in, W, b); //l.activate(inToLayer1, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
assertEquals( out, out2,"Failed at iteration [" + String.valueOf(i) + "]");
|
assertEquals( out, out2,"Failed at iteration [" + i + "]");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7144,9 +7154,10 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRowColumnOpsRank1(){
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
|
public void testRowColumnOpsRank1(Nd4jBackend backend) {
|
||||||
|
|
||||||
for( int i=0; i<6; i++ ) {
|
for( int i = 0; i < 6; i++ ) {
|
||||||
INDArray orig = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
|
INDArray orig = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4);
|
||||||
INDArray in1r = orig.dup();
|
INDArray in1r = orig.dup();
|
||||||
INDArray in2r = orig.dup();
|
INDArray in2r = orig.dup();
|
||||||
|
@ -7954,6 +7965,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled("Crashes")
|
||||||
|
@Tag(TagNames.NEEDS_VERIFY)
|
||||||
public void testRollingMean(Nd4jBackend backend) {
|
public void testRollingMean(Nd4jBackend backend) {
|
||||||
val wsconf = WorkspaceConfiguration.builder()
|
val wsconf = WorkspaceConfiguration.builder()
|
||||||
.initialSize(4L * (32*128*256*256 + 32*128 + 10*1024*1024))
|
.initialSize(4L * (32*128*256*256 + 32*128 + 10*1024*1024))
|
||||||
|
@ -8558,8 +8571,6 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled("Needs verification")
|
|
||||||
@Tag(TagNames.NEEDS_VERIFY)
|
|
||||||
public void testBatchToSpace(Nd4jBackend backend) {
|
public void testBatchToSpace(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray out = Nd4j.create(DataType.FLOAT, 2, 4, 5);
|
INDArray out = Nd4j.create(DataType.FLOAT, 2, 4, 5);
|
||||||
|
@ -8833,7 +8844,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCreateBufferFromByteBuffer(){
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
public void testCreateBufferFromByteBuffer(Nd4jBackend backend){
|
||||||
|
|
||||||
for(DataType dt : DataType.values()){
|
for(DataType dt : DataType.values()){
|
||||||
if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
|
if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
|
||||||
|
|
|
@ -41,6 +41,7 @@ import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
|
||||||
|
|
||||||
|
import java.nio.Buffer;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
|
|
||||||
|
@ -378,7 +379,8 @@ public class DataBufferTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
INDArray arr2 = Nd4j.create(dt, arr.shape());
|
INDArray arr2 = Nd4j.create(dt, arr.shape());
|
||||||
ByteBuffer bb = arr2.data().pointer().asByteBuffer();
|
ByteBuffer bb = arr2.data().pointer().asByteBuffer();
|
||||||
bb.position(0);
|
Buffer buffer = (Buffer) bb;
|
||||||
|
buffer.position(0);
|
||||||
bb.put(b);
|
bb.put(b);
|
||||||
|
|
||||||
Nd4j.getAffinityManager().tagLocation(arr2, AffinityManager.Location.HOST);
|
Nd4j.getAffinityManager().tagLocation(arr2, AffinityManager.Location.HOST);
|
||||||
|
|
|
@ -59,18 +59,15 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
@NativeTag
|
@NativeTag
|
||||||
public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
DataType initialType = Nd4j.dataType();
|
|
||||||
@TempDir Path tempDir;
|
@TempDir Path tempDir;
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void before() {
|
public void before() {
|
||||||
DataTypeUtil.setDTypeForContext(DataType.FLOAT);
|
|
||||||
System.out.println("DATATYPE HERE: " + Nd4j.dataType());
|
System.out.println("DATATYPE HERE: " + Nd4j.dataType());
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
public void after() {
|
public void after() {
|
||||||
DataTypeUtil.setDTypeForContext(initialType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -191,7 +188,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAsBytes(Nd4jBackend backend) {
|
public void testAsBytes(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.create(5);
|
INDArray arr = Nd4j.create(DataType.FLOAT,5);
|
||||||
byte[] d = arr.data().asBytes();
|
byte[] d = arr.data().asBytes();
|
||||||
assertEquals(4 * 5, d.length,getFailureMessage(backend));
|
assertEquals(4 * 5, d.length,getFailureMessage(backend));
|
||||||
INDArray rand = Nd4j.rand(3, 3);
|
INDArray rand = Nd4j.rand(3, 3);
|
||||||
|
@ -245,7 +242,9 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
buffer.reallocate(6);
|
buffer.reallocate(6);
|
||||||
float[] newBuf = buffer.asFloat();
|
float[] newBuf = buffer.asFloat();
|
||||||
assertEquals(6, buffer.capacity());
|
assertEquals(6, buffer.capacity());
|
||||||
assertArrayEquals(old, newBuf, 1e-4F);
|
//note: old and new buf are not equal because java automatically populates the arrays with zeros
|
||||||
|
//the new buffer is actually 1,2,3,4,0,0 because of this
|
||||||
|
assertArrayEquals(new float[]{1,2,3,4,0,0}, newBuf, 1e-4F);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -253,8 +252,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
public void testReallocationWorkspace(Nd4jBackend backend) {
|
public void testReallocationWorkspace(Nd4jBackend backend) {
|
||||||
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
||||||
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
|
||||||
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
|
try(MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID")) {
|
||||||
|
|
||||||
DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4});
|
DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4});
|
||||||
assertTrue(buffer.isAttached());
|
assertTrue(buffer.isAttached());
|
||||||
float[] old = buffer.asFloat();
|
float[] old = buffer.asFloat();
|
||||||
|
@ -262,8 +260,9 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends {
|
||||||
buffer.reallocate(6);
|
buffer.reallocate(6);
|
||||||
assertEquals(6, buffer.capacity());
|
assertEquals(6, buffer.capacity());
|
||||||
float[] newBuf = buffer.asFloat();
|
float[] newBuf = buffer.asFloat();
|
||||||
assertArrayEquals(old, newBuf, 1e-4F);
|
//note: java creates new zeros by default for empty array spots
|
||||||
workspace.close();
|
assertArrayEquals(new float[]{1,2,3,4,0,0}, newBuf, 1e-4F);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
|
|
@ -175,9 +175,11 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void basicBroadcastFailureTest_4(Nd4jBackend backend) {
|
public void basicBroadcastFailureTest_4(Nd4jBackend backend) {
|
||||||
|
assertThrows(IllegalStateException.class,() -> {
|
||||||
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
|
val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f);
|
||||||
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
|
val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2);
|
||||||
val z = x.addi(y);
|
val z = x.addi(y);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
|
|
@ -25,6 +25,9 @@ import lombok.val;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Tag;
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
|
import org.junit.jupiter.api.parallel.Isolated;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
|
@ -52,6 +55,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@NativeTag
|
@NativeTag
|
||||||
@Tag(TagNames.COMPRESSION)
|
@Tag(TagNames.COMPRESSION)
|
||||||
|
@Isolated
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
public class CompressionTests extends BaseNd4jTestWithBackends {
|
public class CompressionTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
@ -412,9 +417,11 @@ public class CompressionTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public void testBitmapEncoding2(Nd4jBackend backend) {
|
public void testBitmapEncoding2(Nd4jBackend backend) {
|
||||||
INDArray initial = Nd4j.create(40000000);
|
INDArray initial = Nd4j.create(DataType.FLOAT,40000000);
|
||||||
INDArray target = Nd4j.create(initial.length());
|
INDArray target = Nd4j.create(DataType.FLOAT,initial.length());
|
||||||
|
|
||||||
initial.addi(1e-3);
|
initial.addi(1e-3);
|
||||||
|
|
||||||
|
|
|
@ -61,6 +61,7 @@ public class DeconvTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public void compareKeras(Nd4jBackend backend) throws Exception {
|
public void compareKeras(Nd4jBackend backend) throws Exception {
|
||||||
File newFolder = testDir.toFile();
|
File newFolder = testDir.toFile();
|
||||||
new ClassPathResource("keras/deconv/").copyDirectory(newFolder);
|
new ClassPathResource("keras/deconv/").copyDirectory(newFolder);
|
||||||
|
|
|
@ -103,7 +103,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering(){
|
public char ordering() {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -566,7 +566,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStridedSliceEdgeCase(){
|
public void testStridedSliceEdgeCase(Nd4jBackend backend) {
|
||||||
INDArray in = Nd4j.scalar(10.0).reshape(1); //Int [1]
|
INDArray in = Nd4j.scalar(10.0).reshape(1); //Int [1]
|
||||||
INDArray begin = Nd4j.ones(DataType.INT, 1);
|
INDArray begin = Nd4j.ones(DataType.INT, 1);
|
||||||
INDArray end = Nd4j.zeros(DataType.INT, 1);
|
INDArray end = Nd4j.zeros(DataType.INT, 1);
|
||||||
|
@ -595,7 +595,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDepthwise(){
|
public void testDepthwise(Nd4jBackend backend) {
|
||||||
INDArray input = Nd4j.create(DataType.DOUBLE, 1,3,8,8);
|
INDArray input = Nd4j.create(DataType.DOUBLE, 1,3,8,8);
|
||||||
INDArray depthwiseWeight = Nd4j.create(DataType.DOUBLE, 1,1,3,2);
|
INDArray depthwiseWeight = Nd4j.create(DataType.DOUBLE, 1,1,3,2);
|
||||||
INDArray bias = Nd4j.create(DataType.DOUBLE, 1, 6);
|
INDArray bias = Nd4j.create(DataType.DOUBLE, 1, 6);
|
||||||
|
@ -660,8 +660,10 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(e, z);
|
assertEquals(e, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
public void testInputValidationMergeMax(){
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
public void testInputValidationMergeMax(Nd4jBackend backend) {
|
||||||
assertThrows(RuntimeException.class,() -> {
|
assertThrows(RuntimeException.class,() -> {
|
||||||
INDArray[] inputs = new INDArray[]{
|
INDArray[] inputs = new INDArray[]{
|
||||||
Nd4j.createFromArray(0.0f, 1.0f, 2.0f).reshape('c', 1, 3),
|
Nd4j.createFromArray(0.0f, 1.0f, 2.0f).reshape('c', 1, 3),
|
||||||
|
@ -683,7 +685,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testUpsampling2dBackprop(){
|
public void testUpsampling2dBackprop(Nd4jBackend backend) {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
int c = 2;
|
int c = 2;
|
||||||
|
@ -729,7 +731,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testIsMaxView(){
|
public void testIsMaxView(Nd4jBackend backend) {
|
||||||
INDArray predictions = Nd4j.rand(DataType.FLOAT, 3, 4, 3, 2);
|
INDArray predictions = Nd4j.rand(DataType.FLOAT, 3, 4, 3, 2);
|
||||||
|
|
||||||
INDArray row = predictions.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.point(0));
|
INDArray row = predictions.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.point(0));
|
||||||
|
@ -748,7 +750,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void isMax4d_2dims(){
|
public void isMax4d_2dims(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 3, 4, 4).permute(0, 2, 3, 1);
|
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 3, 4, 4).permute(0, 2, 3, 1);
|
||||||
|
|
||||||
|
@ -764,7 +766,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSizeTypes(){
|
public void testSizeTypes(Nd4jBackend backend) {
|
||||||
List<DataType> failed = new ArrayList<>();
|
List<DataType> failed = new ArrayList<>();
|
||||||
for(DataType dt : new DataType[]{DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE,
|
for(DataType dt : new DataType[]{DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE,
|
||||||
DataType.UINT64, DataType.UINT32, DataType.UINT16, DataType.UBYTE,
|
DataType.UINT64, DataType.UINT32, DataType.UINT16, DataType.UBYTE,
|
||||||
|
@ -796,7 +798,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testListDiff(){
|
public void testListDiff(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
|
INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
|
||||||
INDArray y = Nd4j.createFromArray(3, 1);
|
INDArray y = Nd4j.createFromArray(3, 1);
|
||||||
|
|
||||||
|
@ -817,7 +819,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTopK1(){
|
public void testTopK1(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
|
INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
|
||||||
INDArray k = Nd4j.scalar(1);
|
INDArray k = Nd4j.scalar(1);
|
||||||
INDArray outValue = Nd4j.create(DataType.DOUBLE, 1);
|
INDArray outValue = Nd4j.create(DataType.DOUBLE, 1);
|
||||||
|
@ -897,7 +899,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAdjustContrastShape(){
|
public void testAdjustContrastShape(Nd4jBackend backend) {
|
||||||
DynamicCustomOp op = DynamicCustomOp.builder("adjust_contrast_v2")
|
DynamicCustomOp op = DynamicCustomOp.builder("adjust_contrast_v2")
|
||||||
.addInputs(Nd4j.create(DataType.FLOAT, 256, 256,3), Nd4j.scalar(0.5f))
|
.addInputs(Nd4j.create(DataType.FLOAT, 256, 256,3), Nd4j.scalar(0.5f))
|
||||||
.build();
|
.build();
|
||||||
|
@ -910,7 +912,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBitCastShape(){
|
public void testBitCastShape(Nd4jBackend backend) {
|
||||||
INDArray out = Nd4j.createUninitialized(1,10);
|
INDArray out = Nd4j.createUninitialized(1,10);
|
||||||
BitCast op = new BitCast(Nd4j.zeros(1,10), DataType.FLOAT.toInt(), out);
|
BitCast op = new BitCast(Nd4j.zeros(1,10), DataType.FLOAT.toInt(), out);
|
||||||
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
|
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
|
||||||
|
@ -1148,7 +1150,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRange(){
|
public void testRange(Nd4jBackend backend) {
|
||||||
DynamicCustomOp op = DynamicCustomOp.builder("range")
|
DynamicCustomOp op = DynamicCustomOp.builder("range")
|
||||||
.addFloatingPointArguments(-1.0, 1.0, 0.01)
|
.addFloatingPointArguments(-1.0, 1.0, 0.01)
|
||||||
.build();
|
.build();
|
||||||
|
@ -1163,7 +1165,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBitCastShape_1(){
|
public void testBitCastShape_1(Nd4jBackend backend) {
|
||||||
val out = Nd4j.createUninitialized(1,10);
|
val out = Nd4j.createUninitialized(1,10);
|
||||||
BitCast op = new BitCast(Nd4j.zeros(DataType.FLOAT,1,10), DataType.INT.toInt(), out);
|
BitCast op = new BitCast(Nd4j.zeros(DataType.FLOAT,1,10), DataType.INT.toInt(), out);
|
||||||
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
|
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
|
||||||
|
@ -1174,7 +1176,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBitCastShape_2(){
|
public void testBitCastShape_2(Nd4jBackend backend) {
|
||||||
val out = Nd4j.createUninitialized(1,10);
|
val out = Nd4j.createUninitialized(1,10);
|
||||||
BitCast op = new BitCast(Nd4j.zeros(DataType.DOUBLE,1,10), DataType.INT.toInt(), out);
|
BitCast op = new BitCast(Nd4j.zeros(DataType.DOUBLE,1,10), DataType.INT.toInt(), out);
|
||||||
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
|
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
|
||||||
|
@ -1283,8 +1285,6 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Tag(TagNames.NEEDS_VERIFY)
|
|
||||||
@Disabled("Implementation needs verification")
|
|
||||||
public void testPolygamma(Nd4jBackend backend) {
|
public void testPolygamma(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3,3);
|
INDArray n = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3,3);
|
||||||
INDArray x = Nd4j.create(DataType.DOUBLE, 3,3);
|
INDArray x = Nd4j.create(DataType.DOUBLE, 3,3);
|
||||||
|
@ -1292,7 +1292,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
INDArray expected = Nd4j.createFromArray(new double[]{4.934802, -16.828796, 97.409088, -771.474243,
|
INDArray expected = Nd4j.createFromArray(new double[]{4.934802, -16.828796, 97.409088, -771.474243,
|
||||||
7691.113770f, -92203.460938f, 1290440.250000, -20644900.000000, 3.71595e+08}).reshape(3,3);
|
7691.113770f, -92203.460938f, 1290440.250000, -20644900.000000, 3.71595e+08}).reshape(3,3);
|
||||||
INDArray output = Nd4j.create(DataType.DOUBLE, expected.shape());
|
INDArray output = Nd4j.create(DataType.DOUBLE, expected.shape());
|
||||||
val op = new Polygamma(x,n,output);
|
val op = new Polygamma(n,x,output);
|
||||||
Nd4j.exec(op);
|
Nd4j.exec(op);
|
||||||
assertEquals(expected, output);
|
assertEquals(expected, output);
|
||||||
}
|
}
|
||||||
|
@ -1424,7 +1424,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAdjustHueShape(){
|
public void testAdjustHueShape(Nd4jBackend backend) {
|
||||||
INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f,
|
INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f,
|
||||||
0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f,
|
0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f,
|
||||||
0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f,
|
0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f,
|
||||||
|
@ -1470,7 +1470,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBitCastShape_3(){
|
public void testBitCastShape_3(Nd4jBackend backend) {
|
||||||
val x = Nd4j.createFromArray(new int[]{1, 2, 3, 4, 5, 6, 7, 8}).reshape(1, 4, 2);
|
val x = Nd4j.createFromArray(new int[]{1, 2, 3, 4, 5, 6, 7, 8}).reshape(1, 4, 2);
|
||||||
val e = Nd4j.createFromArray(new long[]{8589934593L, 17179869187L, 25769803781L, 34359738375L}).reshape(1, 4);
|
val e = Nd4j.createFromArray(new long[]{8589934593L, 17179869187L, 25769803781L, 34359738375L}).reshape(1, 4);
|
||||||
val z = Nd4j.exec(new BitCast(x, DataType.LONG.toInt()))[0];
|
val z = Nd4j.exec(new BitCast(x, DataType.LONG.toInt()))[0];
|
||||||
|
@ -1958,7 +1958,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBatchNormBpNHWC(){
|
public void testBatchNormBpNHWC(Nd4jBackend backend) {
|
||||||
//Nd4j.getEnvironment().allowHelpers(false); //Passes if helpers/MKLDNN is disabled
|
//Nd4j.getEnvironment().allowHelpers(false); //Passes if helpers/MKLDNN is disabled
|
||||||
|
|
||||||
INDArray in = Nd4j.rand(DataType.FLOAT, 2, 4, 4, 3);
|
INDArray in = Nd4j.rand(DataType.FLOAT, 2, 4, 4, 3);
|
||||||
|
@ -1971,13 +1971,13 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
assertEquals(eps, epsStrided);
|
assertEquals(eps, epsStrided);
|
||||||
|
|
||||||
INDArray out1eps = in.like();
|
INDArray out1eps = in.like().castTo(DataType.FLOAT);
|
||||||
INDArray out1m = mean.like();
|
INDArray out1m = mean.like().castTo(DataType.FLOAT);
|
||||||
INDArray out1v = var.like();
|
INDArray out1v = var.like().castTo(DataType.FLOAT);
|
||||||
|
|
||||||
INDArray out2eps = in.like();
|
INDArray out2eps = in.like().castTo(DataType.FLOAT);
|
||||||
INDArray out2m = mean.like();
|
INDArray out2m = mean.like().castTo(DataType.FLOAT);
|
||||||
INDArray out2v = var.like();
|
INDArray out2v = var.like().castTo(DataType.FLOAT);
|
||||||
|
|
||||||
DynamicCustomOp op1 = DynamicCustomOp.builder("batchnorm_bp")
|
DynamicCustomOp op1 = DynamicCustomOp.builder("batchnorm_bp")
|
||||||
.addInputs(in, mean, var, gamma, beta, eps)
|
.addInputs(in, mean, var, gamma, beta, eps)
|
||||||
|
@ -2004,7 +2004,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSpaceToDepthBadStrides(){
|
public void testSpaceToDepthBadStrides(Nd4jBackend backend) {
|
||||||
INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3, 6, 6);
|
INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3, 6, 6);
|
||||||
INDArray inBadStrides = in.permute(1,0,2,3).dup().permute(1,0,2,3);
|
INDArray inBadStrides = in.permute(1,0,2,3).dup().permute(1,0,2,3);
|
||||||
assertEquals(in, inBadStrides);
|
assertEquals(in, inBadStrides);
|
||||||
|
|
|
@ -24,6 +24,7 @@ import lombok.Getter;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Tag;
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
|
@ -45,10 +46,13 @@ import org.nd4j.linalg.dataset.api.preprocessor.stats.DistributionStats;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.stats.MinMaxStats;
|
import org.nd4j.linalg.dataset.api.preprocessor.stats.MinMaxStats;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;
|
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
|
import java.nio.file.Files;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
import static java.util.Arrays.asList;
|
import static java.util.Arrays.asList;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
@ -61,83 +65,91 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
@NativeTag
|
@NativeTag
|
||||||
@Tag(TagNames.FILE_IO)
|
@Tag(TagNames.FILE_IO)
|
||||||
public class NormalizerSerializerTest extends BaseNd4jTestWithBackends {
|
public class NormalizerSerializerTest extends BaseNd4jTestWithBackends {
|
||||||
private File tmpFile;
|
@TempDir File tmpFile;
|
||||||
private NormalizerSerializer SUT;
|
private NormalizerSerializer SUT;
|
||||||
|
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void setUp() throws IOException {
|
public void setUp() throws IOException {
|
||||||
tmpFile = File.createTempFile("test", "preProcessor");
|
|
||||||
tmpFile.deleteOnExit();
|
|
||||||
|
|
||||||
SUT = NormalizerSerializer.getDefault();
|
SUT = NormalizerSerializer.getDefault();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testImagePreProcessingScaler() throws Exception {
|
public void testImagePreProcessingScaler(Nd4jBackend backend) throws Exception {
|
||||||
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
ImagePreProcessingScaler imagePreProcessingScaler = new ImagePreProcessingScaler(0,1);
|
ImagePreProcessingScaler imagePreProcessingScaler = new ImagePreProcessingScaler(0,1);
|
||||||
SUT.write(imagePreProcessingScaler,tmpFile);
|
SUT.write(imagePreProcessingScaler,normalizerFile);
|
||||||
|
|
||||||
ImagePreProcessingScaler restored = SUT.restore(tmpFile);
|
ImagePreProcessingScaler restored = SUT.restore(normalizerFile);
|
||||||
assertEquals(imagePreProcessingScaler,restored);
|
assertEquals(imagePreProcessingScaler,restored);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNormalizerStandardizeNotFitLabels() throws Exception {
|
public void testNormalizerStandardizeNotFitLabels(Nd4jBackend backend) throws Exception {
|
||||||
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
|
|
||||||
NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1),
|
NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1),
|
||||||
Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1));
|
Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1));
|
||||||
|
|
||||||
SUT.write(original, tmpFile);
|
SUT.write(original, normalizerFile);
|
||||||
NormalizerStandardize restored = SUT.restore(tmpFile);
|
NormalizerStandardize restored = SUT.restore(normalizerFile);
|
||||||
|
|
||||||
assertEquals(original, restored);
|
assertEquals(original, restored);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNormalizerStandardizeFitLabels() throws Exception {
|
public void testNormalizerStandardizeFitLabels(Nd4jBackend backend) throws Exception {
|
||||||
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
|
|
||||||
NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1),
|
NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1),
|
||||||
Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1), Nd4j.create(new double[] {4.5, 5.5}).reshape(1, -1),
|
Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1), Nd4j.create(new double[] {4.5, 5.5}).reshape(1, -1),
|
||||||
Nd4j.create(new double[] {6.5, 7.5}).reshape(1, -1));
|
Nd4j.create(new double[] {6.5, 7.5}).reshape(1, -1));
|
||||||
original.fitLabel(true);
|
original.fitLabel(true);
|
||||||
|
|
||||||
SUT.write(original, tmpFile);
|
SUT.write(original, normalizerFile);
|
||||||
NormalizerStandardize restored = SUT.restore(tmpFile);
|
NormalizerStandardize restored = SUT.restore(normalizerFile);
|
||||||
|
|
||||||
assertEquals(original, restored);
|
assertEquals(original, restored);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNormalizerMinMaxScalerNotFitLabels() throws Exception {
|
public void testNormalizerMinMaxScalerNotFitLabels(Nd4jBackend backend) throws Exception {
|
||||||
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
|
|
||||||
NormalizerMinMaxScaler original = new NormalizerMinMaxScaler(0.1, 0.9);
|
NormalizerMinMaxScaler original = new NormalizerMinMaxScaler(0.1, 0.9);
|
||||||
original.setFeatureStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1));
|
original.setFeatureStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1));
|
||||||
|
|
||||||
SUT.write(original, tmpFile);
|
SUT.write(original, normalizerFile);
|
||||||
NormalizerMinMaxScaler restored = SUT.restore(tmpFile);
|
NormalizerMinMaxScaler restored = SUT.restore(normalizerFile);
|
||||||
|
|
||||||
assertEquals(original, restored);
|
assertEquals(original, restored);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNormalizerMinMaxScalerFitLabels() throws Exception {
|
public void testNormalizerMinMaxScalerFitLabels(Nd4jBackend backend) throws Exception {
|
||||||
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
|
|
||||||
NormalizerMinMaxScaler original = new NormalizerMinMaxScaler(0.1, 0.9);
|
NormalizerMinMaxScaler original = new NormalizerMinMaxScaler(0.1, 0.9);
|
||||||
original.setFeatureStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5}));
|
original.setFeatureStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5}));
|
||||||
original.setLabelStats(Nd4j.create(new double[] {4.5, 5.5}), Nd4j.create(new double[] {6.5, 7.5}));
|
original.setLabelStats(Nd4j.create(new double[] {4.5, 5.5}), Nd4j.create(new double[] {6.5, 7.5}));
|
||||||
original.fitLabel(true);
|
original.fitLabel(true);
|
||||||
|
|
||||||
SUT.write(original, tmpFile);
|
SUT.write(original, normalizerFile);
|
||||||
NormalizerMinMaxScaler restored = SUT.restore(tmpFile);
|
NormalizerMinMaxScaler restored = SUT.restore(normalizerFile);
|
||||||
|
|
||||||
assertEquals(original, restored);
|
assertEquals(original, restored);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMultiNormalizerStandardizeNotFitLabels() throws Exception {
|
public void testMultiNormalizerStandardizeNotFitLabels(Nd4jBackend backend) throws Exception {
|
||||||
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
|
|
||||||
MultiNormalizerStandardize original = new MultiNormalizerStandardize();
|
MultiNormalizerStandardize original = new MultiNormalizerStandardize();
|
||||||
original.setFeatureStats(asList(
|
original.setFeatureStats(asList(
|
||||||
new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1),
|
new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1),
|
||||||
|
@ -145,15 +157,17 @@ public class NormalizerSerializerTest extends BaseNd4jTestWithBackends {
|
||||||
new DistributionStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}).reshape(1, -1),
|
new DistributionStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}).reshape(1, -1),
|
||||||
Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1))));
|
Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1))));
|
||||||
|
|
||||||
SUT.write(original, tmpFile);
|
SUT.write(original, normalizerFile);
|
||||||
MultiNormalizerStandardize restored = SUT.restore(tmpFile);
|
MultiNormalizerStandardize restored = SUT.restore(normalizerFile);
|
||||||
|
|
||||||
assertEquals(original, restored);
|
assertEquals(original, restored);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMultiNormalizerStandardizeFitLabels() throws Exception {
|
public void testMultiNormalizerStandardizeFitLabels(Nd4jBackend backend) throws Exception {
|
||||||
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
|
|
||||||
MultiNormalizerStandardize original = new MultiNormalizerStandardize();
|
MultiNormalizerStandardize original = new MultiNormalizerStandardize();
|
||||||
original.setFeatureStats(asList(
|
original.setFeatureStats(asList(
|
||||||
new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1),
|
new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1),
|
||||||
|
@ -168,30 +182,34 @@ public class NormalizerSerializerTest extends BaseNd4jTestWithBackends {
|
||||||
Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1))));
|
Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1))));
|
||||||
original.fitLabel(true);
|
original.fitLabel(true);
|
||||||
|
|
||||||
SUT.write(original, tmpFile);
|
SUT.write(original, normalizerFile);
|
||||||
MultiNormalizerStandardize restored = SUT.restore(tmpFile);
|
MultiNormalizerStandardize restored = SUT.restore(normalizerFile);
|
||||||
|
|
||||||
assertEquals(original, restored);
|
assertEquals(original, restored);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMultiNormalizerMinMaxScalerNotFitLabels() throws Exception {
|
public void testMultiNormalizerMinMaxScalerNotFitLabels(Nd4jBackend backend) throws Exception {
|
||||||
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
|
|
||||||
MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9);
|
MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9);
|
||||||
original.setFeatureStats(asList(
|
original.setFeatureStats(asList(
|
||||||
new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})),
|
new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})),
|
||||||
new MinMaxStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}),
|
new MinMaxStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}),
|
||||||
Nd4j.create(new double[] {7.5, 8.5, 9.5}))));
|
Nd4j.create(new double[] {7.5, 8.5, 9.5}))));
|
||||||
|
|
||||||
SUT.write(original, tmpFile);
|
SUT.write(original, normalizerFile);
|
||||||
MultiNormalizerMinMaxScaler restored = SUT.restore(tmpFile);
|
MultiNormalizerMinMaxScaler restored = SUT.restore(normalizerFile);
|
||||||
|
|
||||||
assertEquals(original, restored);
|
assertEquals(original, restored);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMultiNormalizerMinMaxScalerFitLabels() throws Exception {
|
public void testMultiNormalizerMinMaxScalerFitLabels(Nd4jBackend backend) throws Exception {
|
||||||
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
|
|
||||||
MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9);
|
MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9);
|
||||||
original.setFeatureStats(asList(
|
original.setFeatureStats(asList(
|
||||||
new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})),
|
new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})),
|
||||||
|
@ -204,28 +222,32 @@ public class NormalizerSerializerTest extends BaseNd4jTestWithBackends {
|
||||||
Nd4j.create(new double[] {7.5, 8.5, 9.5}))));
|
Nd4j.create(new double[] {7.5, 8.5, 9.5}))));
|
||||||
original.fitLabel(true);
|
original.fitLabel(true);
|
||||||
|
|
||||||
SUT.write(original, tmpFile);
|
SUT.write(original, normalizerFile);
|
||||||
MultiNormalizerMinMaxScaler restored = SUT.restore(tmpFile);
|
MultiNormalizerMinMaxScaler restored = SUT.restore(normalizerFile);
|
||||||
|
|
||||||
assertEquals(original, restored);
|
assertEquals(original, restored);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMultiNormalizerHybridEmpty() throws Exception {
|
public void testMultiNormalizerHybridEmpty(Nd4jBackend backend) throws Exception {
|
||||||
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
|
|
||||||
MultiNormalizerHybrid original = new MultiNormalizerHybrid();
|
MultiNormalizerHybrid original = new MultiNormalizerHybrid();
|
||||||
original.setInputStats(new HashMap<Integer, NormalizerStats>());
|
original.setInputStats(new HashMap<>());
|
||||||
original.setOutputStats(new HashMap<Integer, NormalizerStats>());
|
original.setOutputStats(new HashMap<>());
|
||||||
|
|
||||||
SUT.write(original, tmpFile);
|
SUT.write(original, normalizerFile);
|
||||||
MultiNormalizerHybrid restored = SUT.restore(tmpFile);
|
MultiNormalizerHybrid restored = SUT.restore(normalizerFile);
|
||||||
|
|
||||||
assertEquals(original, restored);
|
assertEquals(original, restored);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMultiNormalizerHybridGlobalStats() throws Exception {
|
public void testMultiNormalizerHybridGlobalStats(Nd4jBackend backend) throws Exception {
|
||||||
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
|
|
||||||
MultiNormalizerHybrid original = new MultiNormalizerHybrid().minMaxScaleAllInputs().standardizeAllOutputs();
|
MultiNormalizerHybrid original = new MultiNormalizerHybrid().minMaxScaleAllInputs().standardizeAllOutputs();
|
||||||
|
|
||||||
Map<Integer, NormalizerStats> inputStats = new HashMap<>();
|
Map<Integer, NormalizerStats> inputStats = new HashMap<>();
|
||||||
|
@ -239,15 +261,17 @@ public class NormalizerSerializerTest extends BaseNd4jTestWithBackends {
|
||||||
original.setInputStats(inputStats);
|
original.setInputStats(inputStats);
|
||||||
original.setOutputStats(outputStats);
|
original.setOutputStats(outputStats);
|
||||||
|
|
||||||
SUT.write(original, tmpFile);
|
SUT.write(original, normalizerFile);
|
||||||
MultiNormalizerHybrid restored = SUT.restore(tmpFile);
|
MultiNormalizerHybrid restored = SUT.restore(normalizerFile);
|
||||||
|
|
||||||
assertEquals(original, restored);
|
assertEquals(original, restored);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMultiNormalizerHybridGlobalAndSpecificStats() throws Exception {
|
public void testMultiNormalizerHybridGlobalAndSpecificStats(Nd4jBackend backend) throws Exception {
|
||||||
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
|
|
||||||
MultiNormalizerHybrid original = new MultiNormalizerHybrid().standardizeAllInputs().minMaxScaleInput(0, -5, 5)
|
MultiNormalizerHybrid original = new MultiNormalizerHybrid().standardizeAllInputs().minMaxScaleInput(0, -5, 5)
|
||||||
.minMaxScaleAllOutputs(-10, 10).standardizeOutput(1);
|
.minMaxScaleAllOutputs(-10, 10).standardizeOutput(1);
|
||||||
|
|
||||||
|
@ -262,29 +286,35 @@ public class NormalizerSerializerTest extends BaseNd4jTestWithBackends {
|
||||||
original.setInputStats(inputStats);
|
original.setInputStats(inputStats);
|
||||||
original.setOutputStats(outputStats);
|
original.setOutputStats(outputStats);
|
||||||
|
|
||||||
SUT.write(original, tmpFile);
|
SUT.write(original, normalizerFile);
|
||||||
MultiNormalizerHybrid restored = SUT.restore(tmpFile);
|
MultiNormalizerHybrid restored = SUT.restore(normalizerFile);
|
||||||
|
|
||||||
assertEquals(original, restored);
|
assertEquals(original, restored);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test()
|
|
||||||
public void testCustomNormalizerWithoutRegisteredStrategy() throws Exception {
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
public void testCustomNormalizerWithoutRegisteredStrategy(Nd4jBackend backend) throws Exception {
|
||||||
assertThrows(RuntimeException.class, () -> {
|
assertThrows(RuntimeException.class, () -> {
|
||||||
SUT.write(new MyNormalizer(123), tmpFile);
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
|
|
||||||
|
SUT.write(new MyNormalizer(123), normalizerFile);
|
||||||
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCustomNormalizer() throws Exception {
|
public void testCustomNormalizer(Nd4jBackend backend) throws Exception {
|
||||||
|
File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile();
|
||||||
|
|
||||||
MyNormalizer original = new MyNormalizer(42);
|
MyNormalizer original = new MyNormalizer(42);
|
||||||
|
|
||||||
SUT.addStrategy(new MyNormalizerSerializerStrategy());
|
SUT.addStrategy(new MyNormalizerSerializerStrategy());
|
||||||
|
|
||||||
SUT.write(original, tmpFile);
|
SUT.write(original, normalizerFile);
|
||||||
MyNormalizer restored = SUT.restore(tmpFile);
|
MyNormalizer restored = SUT.restore(normalizerFile);
|
||||||
|
|
||||||
assertEquals(original, restored);
|
assertEquals(original, restored);
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,6 +42,7 @@ import org.nd4j.common.util.ArrayUtil;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.nio.Buffer;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -248,7 +249,8 @@ public class Nd4jTest extends BaseNd4jTestWithBackends {
|
||||||
byte[] dataTwo = new byte[floatBuffer.capacity()];
|
byte[] dataTwo = new byte[floatBuffer.capacity()];
|
||||||
floatBuffer.get(dataTwo);
|
floatBuffer.get(dataTwo);
|
||||||
assertArrayEquals(originalData,dataTwo);
|
assertArrayEquals(originalData,dataTwo);
|
||||||
floatBuffer.position(0);
|
Buffer buffer = (Buffer) floatBuffer;
|
||||||
|
buffer.position(0);
|
||||||
|
|
||||||
DataBuffer dataBuffer = Nd4j.createBuffer(new FloatPointer(floatBuffer.asFloatBuffer()),linspace.length(), DataType.FLOAT);
|
DataBuffer dataBuffer = Nd4j.createBuffer(new FloatPointer(floatBuffer.asFloatBuffer()),linspace.length(), DataType.FLOAT);
|
||||||
assertArrayEquals(new float[]{1,2,3,4}, dataBuffer.asFloat(), 1e-5f);
|
assertArrayEquals(new float[]{1,2,3,4}, dataBuffer.asFloat(), 1e-5f);
|
||||||
|
|
|
@ -23,6 +23,8 @@ package org.nd4j.linalg.ops;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
|
@ -116,8 +118,10 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
public void testDistance() throws Exception {
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
|
public void testDistance(Nd4jBackend backend) throws Exception {
|
||||||
INDArray matrix = Nd4j.rand(new int[] {400,10});
|
INDArray matrix = Nd4j.rand(new int[] {400,10});
|
||||||
INDArray rowVector = matrix.getRow(70);
|
INDArray rowVector = matrix.getRow(70);
|
||||||
INDArray resultArr = Nd4j.zeros(400,1);
|
INDArray resultArr = Nd4j.zeros(400,1);
|
||||||
|
@ -127,8 +131,6 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends {
|
||||||
System.out.println("Ran!");
|
System.out.println("Ran!");
|
||||||
});
|
});
|
||||||
|
|
||||||
Thread.sleep(600000);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
|
|
@ -82,11 +82,9 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@NativeTag
|
@NativeTag
|
||||||
public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
|
public class OpExecutionerTestsC extends BaseNd4jTestWithBackends {
|
||||||
DataType initialType = Nd4j.dataType();
|
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
public void after() {
|
public void after() {
|
||||||
Nd4j.setDataType(this.initialType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,8 @@ package org.nd4j.linalg.profiling;
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
|
@ -39,6 +41,7 @@ import org.nd4j.linalg.profiler.ProfilerConfig;
|
||||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
public class InfNanTests extends BaseNd4jTestWithBackends {
|
public class InfNanTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,9 @@ import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
|
import org.junit.jupiter.api.parallel.Isolated;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.common.tests.tags.NativeTag;
|
import org.nd4j.common.tests.tags.NativeTag;
|
||||||
|
@ -52,6 +55,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Isolated
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
public class OperationProfilerTests extends BaseNd4jTestWithBackends {
|
public class OperationProfilerTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
@ -229,9 +234,10 @@ public class OperationProfilerTests extends BaseNd4jTestWithBackends {
|
||||||
assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.TAD_NON_EWS_ACCESS));
|
assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.TAD_NON_EWS_ACCESS));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@ParameterizedTest
|
||||||
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testBadTad4(Nd4jBackend backend) {
|
public void testBadTad4(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.create(2, 4, 5, 6);
|
INDArray x = Nd4j.create(DataType.DOUBLE,2, 4, 5, 6);
|
||||||
|
|
||||||
Pair<DataBuffer, DataBuffer> pair = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, 3);
|
Pair<DataBuffer, DataBuffer> pair = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, 3);
|
||||||
|
|
||||||
|
@ -473,7 +479,7 @@ public class OperationProfilerTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNanPanic(){
|
public void testNanPanic(Nd4jBackend backend) {
|
||||||
try {
|
try {
|
||||||
DynamicCustomOp op = DynamicCustomOp.builder("add")
|
DynamicCustomOp op = DynamicCustomOp.builder("add")
|
||||||
.addInputs(Nd4j.valueArrayOf(10, Double.NaN).castTo(DataType.DOUBLE), Nd4j.scalar(0.0))
|
.addInputs(Nd4j.valueArrayOf(10, Double.NaN).castTo(DataType.DOUBLE), Nd4j.scalar(0.0))
|
||||||
|
|
|
@ -441,6 +441,7 @@ public class RandomTests extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Tag(TagNames.LONG_TEST)
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public void testStepOver1(Nd4jBackend backend) {
|
public void testStepOver1(Nd4jBackend backend) {
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
||||||
|
|
||||||
|
@ -466,6 +467,8 @@ public class RandomTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public void testSum_119(Nd4jBackend backend) {
|
public void testSum_119(Nd4jBackend backend) {
|
||||||
INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000);
|
INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000);
|
||||||
val sum = z2.sumNumber().doubleValue();
|
val sum = z2.sumNumber().doubleValue();
|
||||||
|
@ -474,6 +477,8 @@ public class RandomTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public void testLegacyDistribution1(Nd4jBackend backend) {
|
public void testLegacyDistribution1(Nd4jBackend backend) {
|
||||||
NormalDistribution distribution = new NormalDistribution(new DefaultRandom(), 0.0, 1.0);
|
NormalDistribution distribution = new NormalDistribution(new DefaultRandom(), 0.0, 1.0);
|
||||||
INDArray z1 = distribution.sample(new int[] {1, 1000000});
|
INDArray z1 = distribution.sample(new int[] {1, 1000000});
|
||||||
|
@ -923,9 +928,10 @@ public class RandomTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public void testDeallocation1() throws Exception {
|
public void testDeallocation1() throws Exception {
|
||||||
|
for(int i = 0; i < 1000; i++) {
|
||||||
while (true) {
|
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
||||||
random1.nextInt();
|
random1.nextInt();
|
||||||
|
|
||||||
|
@ -934,6 +940,7 @@ public class RandomTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void someTest(Nd4jBackend backend) {
|
public void someTest(Nd4jBackend backend) {
|
||||||
|
|
|
@ -29,6 +29,8 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Tag;
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
@ -70,6 +72,7 @@ import java.util.Map;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@NativeTag
|
@NativeTag
|
||||||
@Tag(TagNames.RNG)
|
@Tag(TagNames.RNG)
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
public class RngValidationTests extends BaseNd4jTestWithBackends {
|
public class RngValidationTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
@ -129,6 +132,8 @@ public class RngValidationTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
|
@Tag(TagNames.NEEDS_VERIFY)
|
||||||
public void validateRngDistributions(Nd4jBackend backend){
|
public void validateRngDistributions(Nd4jBackend backend){
|
||||||
List<TestCase> testCases = new ArrayList<>();
|
List<TestCase> testCases = new ArrayList<>();
|
||||||
for(DataType type : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
for(DataType type : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
||||||
|
@ -264,7 +269,7 @@ public class RngValidationTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
int count = 1;
|
int count = 1;
|
||||||
for(TestCase tc : testCases){
|
for(TestCase tc : testCases) {
|
||||||
log.info("Starting test case: {} of {}", count, testCases.size());
|
log.info("Starting test case: {} of {}", count, testCases.size());
|
||||||
log.info("{}", tc);
|
log.info("{}", tc);
|
||||||
|
|
||||||
|
@ -314,7 +319,7 @@ public class RngValidationTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(z, z2);
|
assertEquals(z, z2);
|
||||||
|
|
||||||
//Check mean, stdev
|
//Check mean, stdev
|
||||||
if(tc.getExpectedMean() != null){
|
if(tc.getExpectedMean() != null) {
|
||||||
double mean = z.meanNumber().doubleValue();
|
double mean = z.meanNumber().doubleValue();
|
||||||
double re = relError(tc.getExpectedMean(), mean);
|
double re = relError(tc.getExpectedMean(), mean);
|
||||||
double ae = Math.abs(tc.getExpectedMean() - mean);
|
double ae = Math.abs(tc.getExpectedMean() - mean);
|
||||||
|
|
|
@ -44,6 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Tag(TagNames.JACKSON_SERDE)
|
@Tag(TagNames.JACKSON_SERDE)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
public class LargeSerDeTests extends BaseNd4jTestWithBackends {
|
public class LargeSerDeTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
|
|
@ -42,6 +42,7 @@ import org.nd4j.common.io.ClassPathResource;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
@ -56,7 +57,8 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testToNpyFormat(Nd4jBackend backend) throws Exception {
|
public void testToNpyFormat(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
val dir = testDir.toFile();
|
val dir = testDir.resolve("new-dir-" + UUID.randomUUID().toString()).toFile();
|
||||||
|
assertTrue(dir.mkdirs());
|
||||||
new ClassPathResource("numpy_arrays/").copyDirectory(dir);
|
new ClassPathResource("numpy_arrays/").copyDirectory(dir);
|
||||||
|
|
||||||
File[] files = dir.listFiles();
|
File[] files = dir.listFiles();
|
||||||
|
@ -107,14 +109,15 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends {
|
||||||
public void testToNpyFormatScalars(Nd4jBackend backend) throws Exception {
|
public void testToNpyFormatScalars(Nd4jBackend backend) throws Exception {
|
||||||
// File dir = new File("C:\\DL4J\\Git\\dl4j-test-resources\\src\\main\\resources\\numpy_arrays\\scalar");
|
// File dir = new File("C:\\DL4J\\Git\\dl4j-test-resources\\src\\main\\resources\\numpy_arrays\\scalar");
|
||||||
|
|
||||||
val dir = testDir.toFile();
|
val dir = testDir.resolve("new-path0" + UUID.randomUUID().toString()).toFile();
|
||||||
|
dir.mkdirs();
|
||||||
new ClassPathResource("numpy_arrays/scalar/").copyDirectory(dir);
|
new ClassPathResource("numpy_arrays/scalar/").copyDirectory(dir);
|
||||||
|
|
||||||
File[] files = dir.listFiles();
|
File[] files = dir.listFiles();
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
|
|
||||||
for(File f : files){
|
for(File f : files){
|
||||||
if(!f.getPath().endsWith(".npy")){
|
if(!f.getPath().endsWith(".npy")) {
|
||||||
log.warn("Skipping: {}", f);
|
log.warn("Skipping: {}", f);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -161,7 +164,8 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNpzReading(Nd4jBackend backend) throws Exception {
|
public void testNpzReading(Nd4jBackend backend) throws Exception {
|
||||||
|
|
||||||
val dir = testDir.toFile();
|
val dir = testDir.resolve("new-folder-npz").toFile();
|
||||||
|
dir.mkdirs();
|
||||||
new ClassPathResource("numpy_arrays/npz/").copyDirectory(dir);
|
new ClassPathResource("numpy_arrays/npz/").copyDirectory(dir);
|
||||||
|
|
||||||
File[] files = dir.listFiles();
|
File[] files = dir.listFiles();
|
||||||
|
@ -222,7 +226,8 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNpy(Nd4jBackend backend) throws Exception {
|
public void testNpy(Nd4jBackend backend) throws Exception {
|
||||||
for(boolean empty : new boolean[]{false, true}) {
|
for(boolean empty : new boolean[]{false, true}) {
|
||||||
val dir = testDir.toFile();
|
val dir = testDir.resolve("new-dir-1-" + UUID.randomUUID().toString()).toFile();
|
||||||
|
assertTrue(dir.mkdirs());
|
||||||
if(!empty) {
|
if(!empty) {
|
||||||
new ClassPathResource("numpy_arrays/npy/3,4/").copyDirectory(dir);
|
new ClassPathResource("numpy_arrays/npy/3,4/").copyDirectory(dir);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -403,13 +403,13 @@ public class ShapeTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRavel(Nd4jBackend backend) {
|
public void testRavel(Nd4jBackend backend) {
|
||||||
INDArray linspace = Nd4j.linspace(1, 4, 4).reshape(2, 2);
|
INDArray linspace = Nd4j.linspace(1, 4, 4,DataType.DOUBLE).reshape(2, 2);
|
||||||
INDArray asseriton = Nd4j.linspace(1, 4, 4);
|
INDArray asseriton = Nd4j.linspace(1, 4, 4,DataType.DOUBLE);
|
||||||
INDArray raveled = linspace.ravel();
|
INDArray raveled = linspace.ravel();
|
||||||
assertEquals(asseriton, raveled);
|
assertEquals(asseriton, raveled);
|
||||||
|
|
||||||
INDArray tensorLinSpace = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2);
|
INDArray tensorLinSpace = Nd4j.linspace(1, 16, 16,DataType.DOUBLE).reshape(2, 2, 2, 2);
|
||||||
INDArray linspaced = Nd4j.linspace(1, 16, 16);
|
INDArray linspaced = Nd4j.linspace(1, 16, 16,DataType.DOUBLE);
|
||||||
INDArray tensorLinspaceRaveled = tensorLinSpace.ravel();
|
INDArray tensorLinspaceRaveled = tensorLinSpace.ravel();
|
||||||
assertEquals(linspaced, tensorLinspaceRaveled);
|
assertEquals(linspaced, tensorLinspaceRaveled);
|
||||||
|
|
||||||
|
|
|
@ -236,7 +236,6 @@ public class ConcatTestsC extends BaseNd4jTestWithBackends {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testConcat3dv2(Nd4jBackend backend) {
|
public void testConcat3dv2(Nd4jBackend backend) {
|
||||||
|
|
|
@ -21,9 +21,9 @@
|
||||||
package org.nd4j.linalg.specials;
|
package org.nd4j.linalg.specials;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.*;
|
||||||
import org.junit.jupiter.api.Tag;
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
import org.junit.jupiter.api.parallel.Isolated;
|
import org.junit.jupiter.api.parallel.Isolated;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -49,19 +49,30 @@ import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@NativeTag
|
@NativeTag
|
||||||
@Isolated
|
@Isolated
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public class LongTests extends BaseNd4jTestWithBackends {
|
public class LongTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
DataType initialType = Nd4j.dataType();
|
DataType initialType = Nd4j.dataType();
|
||||||
|
@BeforeEach
|
||||||
|
public void beforeEach() {
|
||||||
|
System.gc();
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
public void afterEach() {
|
||||||
|
System.gc();
|
||||||
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Tag(TagNames.LONG_TEST)
|
@Tag(TagNames.LONG_TEST)
|
||||||
public void testSomething1(Nd4jBackend backend) {
|
public void testSomething1(Nd4jBackend backend) {
|
||||||
// we create 2D array, total nr. of elements is 2.4B elements, > MAX_INT
|
// we create 2D array, total nr. of elements is 2.4B elements, > MAX_INT
|
||||||
INDArray huge = Nd4j.create(8000000, 300);
|
INDArray huge = Nd4j.create(DataType.INT8,8000000, 300);
|
||||||
|
|
||||||
// we apply element-wise scalar ops, just to make sure stuff still works
|
// we apply element-wise scalar ops, just to make sure stuff still works
|
||||||
huge.subi(0.5).divi(2);
|
huge.subi(1).divi(2);
|
||||||
|
|
||||||
|
|
||||||
// now we're checking different rows, they should NOT equal
|
// now we're checking different rows, they should NOT equal
|
||||||
|
@ -86,10 +97,10 @@ public class LongTests extends BaseNd4jTestWithBackends {
|
||||||
@Tag(TagNames.LONG_TEST)
|
@Tag(TagNames.LONG_TEST)
|
||||||
public void testSomething2(Nd4jBackend backend) {
|
public void testSomething2(Nd4jBackend backend) {
|
||||||
// we create 2D array, total nr. of elements is 2.4B elements, > MAX_INT
|
// we create 2D array, total nr. of elements is 2.4B elements, > MAX_INT
|
||||||
INDArray huge = Nd4j.create(100, 10);
|
INDArray huge = Nd4j.create(DataType.INT8,100, 10);
|
||||||
|
|
||||||
// we apply element-wise scalar ops, just to make sure stuff still works
|
// we apply element-wise scalar ops, just to make sure stuff still works
|
||||||
huge.subi(0.5).divi(2);
|
huge.subi(1).divi(2);
|
||||||
|
|
||||||
|
|
||||||
// now we're checking different rows, they should NOT equal
|
// now we're checking different rows, they should NOT equal
|
||||||
|
@ -113,7 +124,7 @@ public class LongTests extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Tag(TagNames.LONG_TEST)
|
@Tag(TagNames.LONG_TEST)
|
||||||
public void testLongTadOffsets1(Nd4jBackend backend) {
|
public void testLongTadOffsets1(Nd4jBackend backend) {
|
||||||
INDArray huge = Nd4j.create(230000000, 10);
|
INDArray huge = Nd4j.create(DataType.INT8,230000000, 10);
|
||||||
|
|
||||||
Pair<DataBuffer, DataBuffer> tad = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(huge, 1);
|
Pair<DataBuffer, DataBuffer> tad = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(huge, 1);
|
||||||
|
|
||||||
|
@ -125,10 +136,10 @@ public class LongTests extends BaseNd4jTestWithBackends {
|
||||||
@Tag(TagNames.LONG_TEST)
|
@Tag(TagNames.LONG_TEST)
|
||||||
public void testLongTadOp1(Nd4jBackend backend) {
|
public void testLongTadOp1(Nd4jBackend backend) {
|
||||||
|
|
||||||
double exp = Transforms.manhattanDistance(Nd4j.create(1000).assign(1.0), Nd4j.create(1000).assign(2.0));
|
double exp = Transforms.manhattanDistance(Nd4j.create(DataType.INT16,1000).assign(1.0), Nd4j.create(DataType.INT16,1000).assign(2.0));
|
||||||
|
|
||||||
INDArray hugeX = Nd4j.create(2200000, 1000).assign(1.0);
|
INDArray hugeX = Nd4j.create(DataType.INT16,2200000, 1000).assign(1.0);
|
||||||
INDArray hugeY = Nd4j.create(1, 1000).assign(2.0);
|
INDArray hugeY = Nd4j.create(DataType.INT16,1, 1000).assign(2.0);
|
||||||
|
|
||||||
for (int x = 0; x < hugeX.rows(); x++) {
|
for (int x = 0; x < hugeX.rows(); x++) {
|
||||||
assertEquals(1000, hugeX.getRow(x).sumNumber().intValue(),"Failed at row " + x);
|
assertEquals(1000, hugeX.getRow(x).sumNumber().intValue(),"Failed at row " + x);
|
||||||
|
@ -144,9 +155,8 @@ public class LongTests extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Tag(TagNames.LONG_TEST)
|
@Tag(TagNames.LONG_TEST)
|
||||||
public void testLongTadOp2(Nd4jBackend backend) {
|
public void testLongTadOp2(Nd4jBackend backend) {
|
||||||
|
INDArray hugeX = Nd4j.create(DataType.INT16,2300000, 1000).assign(1.0);
|
||||||
INDArray hugeX = Nd4j.create(2300000, 1000).assign(1.0);
|
hugeX.addiRowVector(Nd4j.create(DataType.INT16,1000).assign(2.0));
|
||||||
hugeX.addiRowVector(Nd4j.create(1000).assign(2.0));
|
|
||||||
|
|
||||||
for (int x = 0; x < hugeX.rows(); x++) {
|
for (int x = 0; x < hugeX.rows(); x++) {
|
||||||
assertEquals( hugeX.getRow(x).sumNumber().intValue(),3000,"Failed at row " + x);
|
assertEquals( hugeX.getRow(x).sumNumber().intValue(),3000,"Failed at row " + x);
|
||||||
|
@ -158,8 +168,8 @@ public class LongTests extends BaseNd4jTestWithBackends {
|
||||||
@Tag(TagNames.LONG_TEST)
|
@Tag(TagNames.LONG_TEST)
|
||||||
public void testLongTadOp2_micro(Nd4jBackend backend) {
|
public void testLongTadOp2_micro(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray hugeX = Nd4j.create(230, 1000).assign(1.0);
|
INDArray hugeX = Nd4j.create(DataType.INT16,230, 1000).assign(1.0);
|
||||||
hugeX.addiRowVector(Nd4j.create(1000).assign(2.0));
|
hugeX.addiRowVector(Nd4j.create(DataType.INT16,1000).assign(2.0));
|
||||||
|
|
||||||
for (int x = 0; x < hugeX.rows(); x++) {
|
for (int x = 0; x < hugeX.rows(); x++) {
|
||||||
assertEquals( 3000, hugeX.getRow(x).sumNumber().intValue(),"Failed at row " + x);
|
assertEquals( 3000, hugeX.getRow(x).sumNumber().intValue(),"Failed at row " + x);
|
||||||
|
@ -171,7 +181,7 @@ public class LongTests extends BaseNd4jTestWithBackends {
|
||||||
@Tag(TagNames.LONG_TEST)
|
@Tag(TagNames.LONG_TEST)
|
||||||
public void testLongTadOp3(Nd4jBackend backend) {
|
public void testLongTadOp3(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray hugeX = Nd4j.create(2300000, 1000).assign(1.0);
|
INDArray hugeX = Nd4j.create(DataType.INT16,2300000, 1000).assign(1.0);
|
||||||
INDArray mean = hugeX.mean(1);
|
INDArray mean = hugeX.mean(1);
|
||||||
|
|
||||||
for (int x = 0; x < hugeX.rows(); x++) {
|
for (int x = 0; x < hugeX.rows(); x++) {
|
||||||
|
@ -184,7 +194,7 @@ public class LongTests extends BaseNd4jTestWithBackends {
|
||||||
@Tag(TagNames.LONG_TEST)
|
@Tag(TagNames.LONG_TEST)
|
||||||
public void testLongTadOp4(Nd4jBackend backend) {
|
public void testLongTadOp4(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray hugeX = Nd4j.create(2300000, 1000).assign(1.0);
|
INDArray hugeX = Nd4j.create(DataType.INT8,2300000, 1000).assign(1.0);
|
||||||
INDArray mean = hugeX.argMax(1);
|
INDArray mean = hugeX.argMax(1);
|
||||||
|
|
||||||
for (int x = 0; x < hugeX.rows(); x++) {
|
for (int x = 0; x < hugeX.rows(); x++) {
|
||||||
|
@ -199,7 +209,7 @@ public class LongTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
List<INDArray> list = new ArrayList<>();
|
List<INDArray> list = new ArrayList<>();
|
||||||
for (int i = 0; i < 2300000; i++) {
|
for (int i = 0; i < 2300000; i++) {
|
||||||
list.add(Nd4j.create(1000).assign(2.0));
|
list.add(Nd4j.create(DataType.INT8,1000).assign(2.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray hugeX = Nd4j.vstack(list);
|
INDArray hugeX = Nd4j.vstack(list);
|
||||||
|
|
|
@ -23,6 +23,8 @@ package org.nd4j.linalg.workspace;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.junit.jupiter.api.*;
|
import org.junit.jupiter.api.*;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
|
@ -54,6 +56,7 @@ import static org.nd4j.linalg.api.buffer.DataType.DOUBLE;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Tag(TagNames.WORKSPACES)
|
@Tag(TagNames.WORKSPACES)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
public class BasicWorkspaceTests extends BaseNd4jTestWithBackends {
|
public class BasicWorkspaceTests extends BaseNd4jTestWithBackends {
|
||||||
DataType initialType = Nd4j.dataType();
|
DataType initialType = Nd4j.dataType();
|
||||||
|
|
||||||
|
@ -959,6 +962,7 @@ public class BasicWorkspaceTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
public void testMmap1(Nd4jBackend backend) {
|
public void testMmap1(Nd4jBackend backend) {
|
||||||
// we don't support MMAP on cuda yet
|
// we don't support MMAP on cuda yet
|
||||||
if (Nd4j.getExecutioner().getClass().getName().toLowerCase().contains("cuda"))
|
if (Nd4j.getExecutioner().getClass().getName().toLowerCase().contains("cuda"))
|
||||||
|
@ -989,12 +993,13 @@ public class BasicWorkspaceTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
|
@Disabled("Still failing even with single thread execution")
|
||||||
public void testMmap2(Nd4jBackend backend) throws Exception {
|
public void testMmap2(Nd4jBackend backend) throws Exception {
|
||||||
// we don't support MMAP on cuda yet
|
// we don't support MMAP on cuda yet
|
||||||
if (Nd4j.getExecutioner().getClass().getName().toLowerCase().contains("cuda"))
|
if (!backend.getEnvironment().isCPU())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
File tmp = File.createTempFile("tmp", "fdsfdf");
|
File tmp = File.createTempFile("tmp", "fdsfdf");
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.workspace;
|
package org.nd4j.linalg.workspace;
|
||||||
|
|
||||||
|
import lombok.SneakyThrows;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
@ -65,8 +66,11 @@ public class CyclicWorkspaceTests extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SneakyThrows
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public void testGc(Nd4jBackend backend) {
|
public void testGc(Nd4jBackend backend) {
|
||||||
val indArray = Nd4j.create(4, 4);
|
val indArray = Nd4j.create(4, 4);
|
||||||
indArray.putRow(0, Nd4j.create(new float[]{0, 2, -2, 0}));
|
indArray.putRow(0, Nd4j.create(new float[]{0, 2, -2, 0}));
|
||||||
|
@ -76,7 +80,7 @@ public class CyclicWorkspaceTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
for (int i = 0; i < 100000000; i++) {
|
for (int i = 0; i < 100000000; i++) {
|
||||||
indArray.getRow(i % 3);
|
indArray.getRow(i % 3);
|
||||||
//Thread.sleep(1);
|
Thread.sleep(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,8 @@ import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Tag;
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
|
@ -48,8 +50,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Tag(TagNames.WORKSPACES)
|
@Tag(TagNames.WORKSPACES)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
public class DebugModeTests extends BaseNd4jTestWithBackends {
|
public class DebugModeTests extends BaseNd4jTestWithBackends {
|
||||||
DataType initialType = Nd4j.dataType();
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,8 @@ import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Tag;
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
|
@ -62,12 +64,11 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends {
|
||||||
public void shutUp() {
|
public void shutUp() {
|
||||||
Nd4j.getMemoryManager().setCurrentWorkspace(null);
|
Nd4j.getMemoryManager().setCurrentWorkspace(null);
|
||||||
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
|
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
|
||||||
Nd4j.setDataType(this.initialType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
@Disabled
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
public void testVariableTimeSeries1(Nd4jBackend backend) {
|
public void testVariableTimeSeries1(Nd4jBackend backend) {
|
||||||
WorkspaceConfiguration configuration = WorkspaceConfiguration
|
WorkspaceConfiguration configuration = WorkspaceConfiguration
|
||||||
.builder()
|
.builder()
|
||||||
|
@ -80,28 +81,28 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1");
|
Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1");
|
||||||
|
|
||||||
assertEquals(0, workspace.getStepNumber());
|
assertEquals(0, workspace.getStepNumber());
|
||||||
|
|
||||||
long requiredMemory = 1000 * Nd4j.sizeOfDataType();
|
long requiredMemory = 1000 * DataType.DOUBLE.width();
|
||||||
long shiftedSize = ((long) (requiredMemory * 1.3)) + (8 - (((long) (requiredMemory * 1.3)) % 8));
|
long shiftedSize = ((long) (requiredMemory * 1.3)) + (8 - (((long) (requiredMemory * 1.3)) % 8));
|
||||||
assertEquals(requiredMemory, workspace.getSpilledSize());
|
assertEquals(requiredMemory, workspace.getSpilledSize());
|
||||||
assertEquals(shiftedSize, workspace.getInitialBlockSize());
|
assertEquals(shiftedSize, workspace.getInitialBlockSize());
|
||||||
assertEquals(workspace.getInitialBlockSize() * 4, workspace.getCurrentSize());
|
assertEquals(workspace.getInitialBlockSize() * 4, workspace.getCurrentSize());
|
||||||
|
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS1")) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS1")) {
|
||||||
Nd4j.create(2000);
|
Nd4j.create(DataType.DOUBLE,2000);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(0, workspace.getStepNumber());
|
assertEquals(0, workspace.getStepNumber());
|
||||||
|
|
||||||
assertEquals(1000 * Nd4j.sizeOfDataType(), workspace.getSpilledSize());
|
assertEquals(1000 * DataType.DOUBLE.width(), workspace.getSpilledSize());
|
||||||
assertEquals(2000 * Nd4j.sizeOfDataType(), workspace.getPinnedSize());
|
assertEquals(2000 * DataType.DOUBLE.width(), workspace.getPinnedSize());
|
||||||
|
|
||||||
assertEquals(0, workspace.getDeviceOffset());
|
assertEquals(0, workspace.getDeviceOffset());
|
||||||
|
|
||||||
|
@ -116,8 +117,8 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends {
|
||||||
for (int e = 0; e < 4; e++) {
|
for (int e = 0; e < 4; e++) {
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals((i + 1) * workspace.getInitialBlockSize(),
|
assertEquals((i + 1) * workspace.getInitialBlockSize(),
|
||||||
|
@ -144,9 +145,9 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends {
|
||||||
// we just do huge loop now, with pinned stuff in it
|
// we just do huge loop now, with pinned stuff in it
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
|
|
||||||
assertEquals(1500 * Nd4j.sizeOfDataType(), workspace.getThisCycleAllocations());
|
assertEquals(1500 * Nd4j.sizeOfDataType(), workspace.getThisCycleAllocations());
|
||||||
}
|
}
|
||||||
|
@ -160,8 +161,8 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends {
|
||||||
// and we do another clean loo, without pinned stuff in it, to ensure all pinned allocates are gone
|
// and we do another clean loo, without pinned stuff in it, to ensure all pinned allocates are gone
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,13 +187,12 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends {
|
||||||
// workspace.enableDebug(true);
|
// workspace.enableDebug(true);
|
||||||
|
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(0, workspace.getStepNumber());
|
assertEquals(0, workspace.getStepNumber());
|
||||||
|
long requiredMemory = 1000 * DataType.DOUBLE.width();
|
||||||
long requiredMemory = 1000 * Nd4j.sizeOfDataType();
|
|
||||||
long shiftedSize = ((long) (requiredMemory * 1.3)) + (8 - (((long) (requiredMemory * 1.3)) % 8));
|
long shiftedSize = ((long) (requiredMemory * 1.3)) + (8 - (((long) (requiredMemory * 1.3)) % 8));
|
||||||
assertEquals(requiredMemory, workspace.getSpilledSize());
|
assertEquals(requiredMemory, workspace.getSpilledSize());
|
||||||
assertEquals(shiftedSize, workspace.getInitialBlockSize());
|
assertEquals(shiftedSize, workspace.getInitialBlockSize());
|
||||||
|
@ -200,9 +200,9 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
Nd4j.create(500);
|
Nd4j.create(DataType.DOUBLE,500);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -226,11 +226,11 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends {
|
||||||
Nd4jWorkspace workspace =
|
Nd4jWorkspace workspace =
|
||||||
(Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "WS109");
|
(Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "WS109");
|
||||||
|
|
||||||
INDArray row = Nd4j.linspace(1, 10, 10);
|
INDArray row = Nd4j.linspace(1, 10, 10).castTo(DataType.DOUBLE);
|
||||||
INDArray exp = Nd4j.create(10).assign(2.0);
|
INDArray exp = Nd4j.create(DataType.DOUBLE,10).assign(2.0);
|
||||||
INDArray result = null;
|
INDArray result = null;
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS109")) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS109")) {
|
||||||
INDArray matrix = Nd4j.create(10, 10);
|
INDArray matrix = Nd4j.create(DataType.DOUBLE,10, 10);
|
||||||
for (int e = 0; e < matrix.rows(); e++)
|
for (int e = 0; e < matrix.rows(); e++)
|
||||||
matrix.getRow(e).assign(row);
|
matrix.getRow(e).assign(row);
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,8 @@ import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Tag;
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.parallel.Execution;
|
||||||
|
import org.junit.jupiter.api.parallel.ExecutionMode;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
|
||||||
|
@ -57,6 +59,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Tag(TagNames.WORKSPACES)
|
@Tag(TagNames.WORKSPACES)
|
||||||
@NativeTag
|
@NativeTag
|
||||||
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
public class WorkspaceProviderTests extends BaseNd4jTestWithBackends {
|
public class WorkspaceProviderTests extends BaseNd4jTestWithBackends {
|
||||||
private static final WorkspaceConfiguration basicConfiguration = WorkspaceConfiguration.builder().initialSize(81920)
|
private static final WorkspaceConfiguration basicConfiguration = WorkspaceConfiguration.builder().initialSize(81920)
|
||||||
.overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL).policyLearning(LearningPolicy.NONE)
|
.overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL).policyLearning(LearningPolicy.NONE)
|
||||||
|
@ -119,7 +122,6 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends {
|
||||||
public void shutUp() {
|
public void shutUp() {
|
||||||
Nd4j.getMemoryManager().setCurrentWorkspace(null);
|
Nd4j.getMemoryManager().setCurrentWorkspace(null);
|
||||||
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
|
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
|
||||||
Nd4j.setDataType(this.initialType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -144,7 +146,7 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends {
|
||||||
for (int x = 0; x < 100; x++) {
|
for (int x = 0; x < 100; x++) {
|
||||||
try (Nd4jWorkspace wsI = (Nd4jWorkspace) Nd4j.getWorkspaceManager()
|
try (Nd4jWorkspace wsI = (Nd4jWorkspace) Nd4j.getWorkspaceManager()
|
||||||
.getWorkspaceForCurrentThread(configuration, "ITER").notifyScopeEntered()) {
|
.getWorkspaceForCurrentThread(configuration, "ITER").notifyScopeEntered()) {
|
||||||
INDArray array = Nd4j.create(100);
|
INDArray array = Nd4j.create(DataType.DOUBLE,100);
|
||||||
}
|
}
|
||||||
|
|
||||||
// only checking after workspace is initialized
|
// only checking after workspace is initialized
|
||||||
|
@ -174,7 +176,7 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends {
|
||||||
try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager()
|
try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager()
|
||||||
.getWorkspaceForCurrentThread(configuration, "ITER").notifyScopeEntered()) {
|
.getWorkspaceForCurrentThread(configuration, "ITER").notifyScopeEntered()) {
|
||||||
|
|
||||||
INDArray array = Nd4j.create(100);
|
INDArray array = Nd4j.create(DataType.DOUBLE,100);
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration,
|
Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration,
|
||||||
|
@ -200,7 +202,7 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMultithreading1() throws Exception {
|
public void testMultithreading1(Nd4jBackend backend) throws Exception {
|
||||||
final List<MemoryWorkspace> workspaces = new CopyOnWriteArrayList<>();
|
final List<MemoryWorkspace> workspaces = new CopyOnWriteArrayList<>();
|
||||||
Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration);
|
Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration);
|
||||||
|
|
||||||
|
@ -283,21 +285,23 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
|
@Disabled
|
||||||
|
@Tag(TagNames.NEEDS_VERIFY)
|
||||||
public void testNestedWorkspacesOverlap1(Nd4jBackend backend) {
|
public void testNestedWorkspacesOverlap1(Nd4jBackend backend) {
|
||||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
|
||||||
Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration);
|
Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration);
|
||||||
try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1").notifyScopeEntered()) {
|
try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1").notifyScopeEntered()) {
|
||||||
INDArray array = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
|
INDArray array = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
|
||||||
|
|
||||||
long reqMem = 5 * Nd4j.sizeOfDataType();
|
long reqMem = 5 * array.dataType().width();
|
||||||
assertEquals(reqMem + (Nd4jWorkspace.alignmentBase - reqMem % Nd4jWorkspace.alignmentBase), ws1.getPrimaryOffset());
|
long add = ((Nd4jWorkspace.alignmentBase / 2) - reqMem % (Nd4jWorkspace.alignmentBase / 2));
|
||||||
|
assertEquals(reqMem + add, ws1.getPrimaryOffset());
|
||||||
try (Nd4jWorkspace ws2 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS2").notifyScopeEntered()) {
|
try (Nd4jWorkspace ws2 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS2").notifyScopeEntered()) {
|
||||||
|
|
||||||
INDArray array2 = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
|
INDArray array2 = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
|
||||||
|
|
||||||
reqMem = 5 * Nd4j.sizeOfDataType();
|
reqMem = 5 * array2.dataType().width();
|
||||||
assertEquals(reqMem + (Nd4jWorkspace.alignmentBase - reqMem % Nd4jWorkspace.alignmentBase), ws1.getPrimaryOffset());
|
assertEquals(reqMem + ((Nd4jWorkspace.alignmentBase / 2) - reqMem % (Nd4jWorkspace.alignmentBase / 2)), ws1.getPrimaryOffset());
|
||||||
assertEquals(reqMem + (Nd4jWorkspace.alignmentBase - reqMem % Nd4jWorkspace.alignmentBase), ws2.getPrimaryOffset());
|
assertEquals(reqMem + ((Nd4jWorkspace.alignmentBase / 2) - reqMem % (Nd4jWorkspace.alignmentBase / 2)), ws2.getPrimaryOffset());
|
||||||
|
|
||||||
try (Nd4jWorkspace ws3 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1")
|
try (Nd4jWorkspace ws3 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1")
|
||||||
.notifyScopeBorrowed()) {
|
.notifyScopeBorrowed()) {
|
||||||
|
@ -305,8 +309,8 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
INDArray array3 = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
|
INDArray array3 = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
|
||||||
|
|
||||||
assertEquals(reqMem + (Nd4jWorkspace.alignmentBase - reqMem % Nd4jWorkspace.alignmentBase), ws2.getPrimaryOffset());
|
assertEquals(reqMem + ((Nd4jWorkspace.alignmentBase / 2) - reqMem % (Nd4jWorkspace.alignmentBase / 2)), ws2.getPrimaryOffset());
|
||||||
assertEquals((reqMem + (Nd4jWorkspace.alignmentBase - reqMem % Nd4jWorkspace.alignmentBase)) * 2, ws1.getPrimaryOffset());
|
assertEquals((reqMem + ((Nd4jWorkspace.alignmentBase / 2) - reqMem % (Nd4jWorkspace.alignmentBase / 2))) * 2, ws1.getPrimaryOffset());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -317,7 +321,7 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testWorkspacesSerde3() throws Exception {
|
public void testWorkspacesSerde3() throws Exception {
|
||||||
INDArray array = Nd4j.create(10).assign(1.0);
|
INDArray array = Nd4j.create(DataType.DOUBLE,10).assign(1.0);
|
||||||
INDArray restored = null;
|
INDArray restored = null;
|
||||||
|
|
||||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||||
|
@ -600,7 +604,7 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testReallocate1(Nd4jBackend backend) {
|
public void testReallocate1(Nd4jBackend backend) {
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateConfiguration, "WS_1")) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateConfiguration, "WS_1")) {
|
||||||
INDArray array = Nd4j.create(100);
|
INDArray array = Nd4j.create(DataType.DOUBLE,100);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -612,7 +616,7 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends {
|
||||||
assertEquals(100 * Nd4j.sizeOfDataType(), workspace.getCurrentSize());
|
assertEquals(100 * Nd4j.sizeOfDataType(), workspace.getCurrentSize());
|
||||||
|
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateConfiguration, "WS_1")) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateConfiguration, "WS_1")) {
|
||||||
INDArray array = Nd4j.create(1000);
|
INDArray array = Nd4j.create(DataType.DOUBLE,1000);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(1000 * Nd4j.sizeOfDataType(), workspace.getMaxCycleAllocations());
|
assertEquals(1000 * Nd4j.sizeOfDataType(), workspace.getMaxCycleAllocations());
|
||||||
|
@ -634,14 +638,14 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends {
|
||||||
public void testNestedWorkspaces11(Nd4jBackend backend) {
|
public void testNestedWorkspaces11(Nd4jBackend backend) {
|
||||||
for (int x = 1; x < 10; x++) {
|
for (int x = 1; x < 10; x++) {
|
||||||
try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) {
|
try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) {
|
||||||
INDArray array1 = Nd4j.create(100 * x);
|
INDArray array1 = Nd4j.create(DataType.DOUBLE,100 * x);
|
||||||
|
|
||||||
for (int i = 1; i < 10; i++) {
|
for (int i = 1; i < 10; i++) {
|
||||||
try (MemoryWorkspace ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) {
|
try (MemoryWorkspace ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) {
|
||||||
INDArray array2 = Nd4j.create(100 * x);
|
INDArray array2 = Nd4j.create(DataType.DOUBLE,100 * x);
|
||||||
for (int e = 1; e < 10; e++) {
|
for (int e = 1; e < 10; e++) {
|
||||||
try (MemoryWorkspace ws3 = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(basicConfiguration, "WS_1").notifyScopeBorrowed()) {
|
try (MemoryWorkspace ws3 = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(basicConfiguration, "WS_1").notifyScopeBorrowed()) {
|
||||||
INDArray array3 = Nd4j.create(100 * x);
|
INDArray array3 = Nd4j.create(DataType.DOUBLE,100 * x);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,7 +55,7 @@ public abstract class BaseND4JTest {
|
||||||
* Override this method to set the default timeout for methods in the test class
|
* Override this method to set the default timeout for methods in the test class
|
||||||
*/
|
*/
|
||||||
public long getTimeoutMilliseconds(){
|
public long getTimeoutMilliseconds(){
|
||||||
return 90_000;
|
return 180_000;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -95,7 +95,7 @@ public abstract class BaseND4JTest {
|
||||||
/**
|
/**
|
||||||
* @return True if integration tests maven profile is enabled, false otherwise.
|
* @return True if integration tests maven profile is enabled, false otherwise.
|
||||||
*/
|
*/
|
||||||
public boolean isIntegrationTests(){
|
public boolean isIntegrationTests() {
|
||||||
if(integrationTest == null){
|
if(integrationTest == null){
|
||||||
String prop = System.getenv("DL4J_INTEGRATION_TESTS");
|
String prop = System.getenv("DL4J_INTEGRATION_TESTS");
|
||||||
integrationTest = Boolean.parseBoolean(prop);
|
integrationTest = Boolean.parseBoolean(prop);
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* * ******************************************************************************
|
||||||
|
* * *
|
||||||
|
* * *
|
||||||
|
* * * This program and the accompanying materials are made available under the
|
||||||
|
* * * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* * *
|
||||||
|
* * * See the NOTICE file distributed with this work for additional
|
||||||
|
* * * information regarding copyright ownership.
|
||||||
|
* * * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * * License for the specific language governing permissions and limitations
|
||||||
|
* * * under the License.
|
||||||
|
* * *
|
||||||
|
* * * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* * *****************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.nd4j.common.tests.tags;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Tag;
|
||||||
|
|
||||||
|
import java.lang.annotation.ElementType;
|
||||||
|
import java.lang.annotation.Retention;
|
||||||
|
import java.lang.annotation.RetentionPolicy;
|
||||||
|
import java.lang.annotation.Target;
|
||||||
|
|
||||||
|
@Target({ElementType.TYPE, ElementType.METHOD})
|
||||||
|
@Retention(RetentionPolicy.RUNTIME)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
public @interface ExpensiveTest {
|
||||||
|
}
|
|
@ -50,4 +50,5 @@ public class TagNames {
|
||||||
public final static String PYTHON = "python";
|
public final static String PYTHON = "python";
|
||||||
public final static String LONG_TEST = "long-running-test";
|
public final static String LONG_TEST = "long-running-test";
|
||||||
public final static String NEEDS_VERIFY = "needs-verify"; //tests that need verification of issue
|
public final static String NEEDS_VERIFY = "needs-verify"; //tests that need verification of issue
|
||||||
|
public final static String LARGE_RESOURCES = "large-resources";
|
||||||
}
|
}
|
||||||
|
|
|
@ -106,6 +106,7 @@
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<configuration>
|
<configuration>
|
||||||
|
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
|
||||||
<forkCount>${cpu.core.count}</forkCount>
|
<forkCount>${cpu.core.count}</forkCount>
|
||||||
<reuseForks>false</reuseForks>
|
<reuseForks>false</reuseForks>
|
||||||
<environmentVariables>
|
<environmentVariables>
|
||||||
|
@ -116,7 +117,8 @@
|
||||||
<include>*.java</include>
|
<include>*.java</include>
|
||||||
<include>**/*.java</include>
|
<include>**/*.java</include>
|
||||||
</includes>
|
</includes>
|
||||||
<argLine> -Xmx8g </argLine>
|
<argLine>-Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine>
|
||||||
|
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
|
@ -140,9 +142,11 @@
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<configuration>
|
<configuration>
|
||||||
|
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
|
||||||
<forkCount>${cpu.core.count}</forkCount>
|
<forkCount>${cpu.core.count}</forkCount>
|
||||||
<reuseForks>false</reuseForks>
|
<reuseForks>false</reuseForks>
|
||||||
<argLine>-Xmx8g</argLine>
|
<argLine>-Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine>
|
||||||
|
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
|
|
|
@ -98,6 +98,7 @@
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<configuration>
|
<configuration>
|
||||||
|
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
|
||||||
<forkCount>${cpu.core.count}</forkCount>
|
<forkCount>${cpu.core.count}</forkCount>
|
||||||
<reuseForks>false</reuseForks>
|
<reuseForks>false</reuseForks>
|
||||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
<testSourceDirectory>src/test/java</testSourceDirectory>
|
||||||
|
@ -105,7 +106,8 @@
|
||||||
<include>*.java</include>
|
<include>*.java</include>
|
||||||
<include>**/*.java</include>
|
<include>**/*.java</include>
|
||||||
</includes>
|
</includes>
|
||||||
<argLine> </argLine>
|
<argLine>-Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine>
|
||||||
|
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
|
|
|
@ -20,35 +20,61 @@
|
||||||
|
|
||||||
package org.nd4j;
|
package org.nd4j;
|
||||||
|
|
||||||
|
import com.sun.jna.Platform;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.api.java.JavaRDD;
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
import org.apache.spark.api.java.function.VoidFunction;
|
import org.apache.spark.api.java.function.VoidFunction;
|
||||||
import org.apache.spark.broadcast.Broadcast;
|
import org.apache.spark.broadcast.Broadcast;
|
||||||
import org.apache.spark.serializer.SerializerInstance;
|
import org.apache.spark.serializer.SerializerInstance;
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.*;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
|
||||||
import org.junit.jupiter.api.Disabled;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.nd4j.common.primitives.*;
|
import org.nd4j.common.primitives.*;
|
||||||
|
import org.nd4j.common.resources.Downloader;
|
||||||
import org.nd4j.common.tests.BaseND4JTest;
|
import org.nd4j.common.tests.BaseND4JTest;
|
||||||
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import scala.Tuple2;
|
import scala.Tuple2;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.net.URI;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
@Disabled("Ignoring due to flaky nature of tests")
|
@Slf4j
|
||||||
|
@Tag(TagNames.SPARK)
|
||||||
|
@Tag(TagNames.DIST_SYSTEMS)
|
||||||
public class TestNd4jKryoSerialization extends BaseND4JTest {
|
public class TestNd4jKryoSerialization extends BaseND4JTest {
|
||||||
|
|
||||||
private JavaSparkContext sc;
|
private JavaSparkContext sc;
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
@SneakyThrows
|
||||||
|
public static void beforeAll() {
|
||||||
|
if(Platform.isWindows()) {
|
||||||
|
File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp");
|
||||||
|
File binDir = new File(hadoopHome,"bin");
|
||||||
|
if(!binDir.exists())
|
||||||
|
binDir.mkdirs();
|
||||||
|
File outputFile = new File(binDir,"winutils.exe");
|
||||||
|
if(!outputFile.exists()) {
|
||||||
|
log.info("Fixing spark for windows");
|
||||||
|
Downloader.download("winutils.exe",
|
||||||
|
URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(),
|
||||||
|
outputFile,"db24b404d2331a1bec7443336a5171f1",3);
|
||||||
|
}
|
||||||
|
|
||||||
|
System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
public void before() {
|
public void before() {
|
||||||
SparkConf sparkConf = new SparkConf();
|
SparkConf sparkConf = new SparkConf();
|
||||||
|
|
|
@ -49,6 +49,12 @@
|
||||||
|
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-common-tests</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>samediff-import-api</artifactId>
|
<artifactId>samediff-import-api</artifactId>
|
||||||
|
|
|
@ -40,6 +40,7 @@ import org.apache.commons.io.FileUtils
|
||||||
import org.junit.jupiter.api.Disabled
|
import org.junit.jupiter.api.Disabled
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
import org.nd4j.common.resources.Downloader
|
import org.nd4j.common.resources.Downloader
|
||||||
|
import org.nd4j.common.tests.tags.ExpensiveTest
|
||||||
import org.nd4j.common.util.ArchiveUtils
|
import org.nd4j.common.util.ArchiveUtils
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
import org.nd4j.samediff.frameworkimport.onnx.importer.OnnxFrameworkImporter
|
import org.nd4j.samediff.frameworkimport.onnx.importer.OnnxFrameworkImporter
|
||||||
|
@ -50,7 +51,7 @@ import java.io.File
|
||||||
import java.net.URI
|
import java.net.URI
|
||||||
|
|
||||||
data class InputDataset(val dataSetIndex: Int,val inputPaths: List<String>,val outputPaths: List<String>)
|
data class InputDataset(val dataSetIndex: Int,val inputPaths: List<String>,val outputPaths: List<String>)
|
||||||
@Disabled
|
@ExpensiveTest
|
||||||
class TestPretrainedModels {
|
class TestPretrainedModels {
|
||||||
|
|
||||||
val modelBaseUrl = "https://media.githubusercontent.com/media/onnx/models/master"
|
val modelBaseUrl = "https://media.githubusercontent.com/media/onnx/models/master"
|
||||||
|
|
|
@ -22,6 +22,7 @@ package org.nd4j.imports.tfgraphs;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
import org.junit.jupiter.api.Tag;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -33,6 +34,7 @@ import org.nd4j.autodiff.samediff.transform.OpPredicate;
|
||||||
import org.nd4j.autodiff.samediff.transform.SubGraph;
|
import org.nd4j.autodiff.samediff.transform.SubGraph;
|
||||||
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
|
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
|
||||||
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
|
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
|
||||||
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.graph.ui.LogFileWriter;
|
import org.nd4j.graph.ui.LogFileWriter;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.imports.tensorflow.TFImportOverride;
|
import org.nd4j.imports.tensorflow.TFImportOverride;
|
||||||
|
@ -55,7 +57,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Disabled("AB 2019/05/21 - JVM Crash on linux-x86_64-cuda-9.2, linux-ppc64le-cpu - Issue #7657")
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public class BERTGraphTest extends BaseNd4jTestWithBackends {
|
public class BERTGraphTest extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -44,7 +44,7 @@ public class CustomOpTests extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPad(Nd4jBackend backend){
|
public void testPad(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray in = Nd4j.create(DataType.FLOAT, 1, 28, 28, 264);
|
INDArray in = Nd4j.create(DataType.FLOAT, 1, 28, 28, 264);
|
||||||
INDArray pad = Nd4j.createFromArray(new int[][]{{0,0},{0,1},{0,1},{0,0}});
|
INDArray pad = Nd4j.createFromArray(new int[][]{{0,0},{0,1},{0,1},{0,0}});
|
||||||
|
|
|
@ -27,6 +27,7 @@ import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.params.provider.Arguments;
|
import org.junit.jupiter.params.provider.Arguments;
|
||||||
|
|
||||||
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
|
@ -41,7 +42,8 @@ import java.util.stream.Stream;
|
||||||
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Disabled("AB 2019/05/21 - JVM Crashes - Issue #7657")
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public class TFGraphTestAllLibnd4j { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests
|
public class TFGraphTestAllLibnd4j { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests
|
||||||
|
|
||||||
private Map<String, INDArray> inputs;
|
private Map<String, INDArray> inputs;
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.junit.jupiter.api.*;
|
||||||
|
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.Arguments;
|
import org.junit.jupiter.params.provider.Arguments;
|
||||||
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
|
@ -38,6 +39,8 @@ import java.util.*;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests
|
public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,7 @@ import org.junit.jupiter.api.io.TempDir;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.Arguments;
|
import org.junit.jupiter.params.provider.Arguments;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
import org.nd4j.common.tests.tags.TagNames;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -44,7 +45,8 @@ import java.util.Map;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
|
||||||
@Disabled
|
@Tag(TagNames.LONG_TEST)
|
||||||
|
@Tag(TagNames.LARGE_RESOURCES)
|
||||||
public class TFGraphTestList {
|
public class TFGraphTestList {
|
||||||
|
|
||||||
|
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue