Add large test tags, ensure that small runs finish, get rid of test timeouts
This commit is contained in:
		
							parent
							
								
									652b854083
								
							
						
					
					
						commit
						3e60302e8c
					
				| @ -31,7 +31,7 @@ jobs: | |||||||
|               protoc --version |               protoc --version | ||||||
|               cd dl4j-test-resources-master && mvn clean install -DskipTests && cd .. |               cd dl4j-test-resources-master && mvn clean install -DskipTests && cd .. | ||||||
|               export OMP_NUM_THREADS=1 |               export OMP_NUM_THREADS=1 | ||||||
|               mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test |               mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j"  -DexcludedGroups="long-running-tests,large-resources" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test --fail-never | ||||||
| 
 | 
 | ||||||
|   windows-x86_64: |   windows-x86_64: | ||||||
|     runs-on: windows-2019 |     runs-on: windows-2019 | ||||||
| @ -44,7 +44,7 @@ jobs: | |||||||
|         run: | |         run: | | ||||||
|               set "PATH=C:\msys64\usr\bin;%PATH%" |               set "PATH=C:\msys64\usr\bin;%PATH%" | ||||||
|               export OMP_NUM_THREADS=1 |               export OMP_NUM_THREADS=1 | ||||||
|               mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j"  -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test |               mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j"  -DexcludedGroups="long-running-tests,large-resources" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test --fail-never | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -22,5 +22,6 @@ jobs: | |||||||
|           cmake --version |           cmake --version | ||||||
|           protoc --version |           protoc --version | ||||||
|           export OMP_NUM_THREADS=1 |           export OMP_NUM_THREADS=1 | ||||||
|           mvn   -DexcludedGroups=long-running-tests -DskipTestResourceEnforcement=true -Ptestresources  -Pintegration-tests  -Pnd4j-tests-cpu   clean test |           mvn   -DexcludedGroups="long-running-tests,large-resources" -DskipTestResourceEnforcement=true -Ptestresources  -Pintegration-tests  -Pnd4j-tests-cpu   clean test | ||||||
|  |           mvn -Ptestresources -Pnd4j-tests-cpu  -Dtest.offheap.size=14g -Dtest.heap.size=6g  clean test | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -34,5 +34,5 @@ jobs: | |||||||
|           cmake --version |           cmake --version | ||||||
|           protoc --version |           protoc --version | ||||||
|           export OMP_NUM_THREADS=1 |           export OMP_NUM_THREADS=1 | ||||||
|           mvn   -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j"  -Pnd4j-tests-cpu --also-make  clean test |           mvn   -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j"   -DexcludedGroups="long-running-tests,large-resources" -Pnd4j-tests-cpu --also-make  clean test --fail-never | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -34,5 +34,6 @@ jobs: | |||||||
|           cmake --version |           cmake --version | ||||||
|           protoc --version |           protoc --version | ||||||
|           export OMP_NUM_THREADS=1 |           export OMP_NUM_THREADS=1 | ||||||
|           mvn  -DexcludedGroups=long-running-tests  -DskipTestResourceEnforcement=true -Ptestresources  -Pintegration-tests  -Pnd4j-tests-cuda   clean test |           mvn  -DexcludedGroups="long-running-tests,large-resources"  -DskipTestResourceEnforcement=true -Ptestresources  -Pintegration-tests  -Pnd4j-tests-cuda   clean test --fail-never | ||||||
|  |           mvn -Ptestresources -Pnd4j-tests-cuda  -Dtest.offheap.size=14g -Dtest.heap.size=6g  clean test --fail-never | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -35,5 +35,5 @@ jobs: | |||||||
|           protoc --version |           protoc --version | ||||||
|           bash ./change-cuda-versions.sh 11.2 |           bash ./change-cuda-versions.sh 11.2 | ||||||
|           export OMP_NUM_THREADS=1 |           export OMP_NUM_THREADS=1 | ||||||
|           mvn  -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-cuda-11.2,:samediff-import,:libnd4j" -Dlibnd4j.helper=cudnn   -Ptest-nd4j-cuda  --also-make -Dlibnd4j.chip=cuda clean test |           mvn  -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-cuda-11.2,:samediff-import,:libnd4j" -Dlibnd4j.helper=cudnn   -Ptest-nd4j-cuda  --also-make -Dlibnd4j.chip=cuda clean test --fail-never | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -42,6 +42,17 @@ A few kinds of tags exist: | |||||||
| 7. RNG: (rng) for RNG related tests | 7. RNG: (rng) for RNG related tests | ||||||
| 8. Samediff:(samediff) samediff related tests | 8. Samediff:(samediff) samediff related tests | ||||||
| 9. Training related functionality | 9. Training related functionality | ||||||
|  | 10. long-running-tests: The longer running tests that take a longer execution time | ||||||
|  | 11. large-resources: tests requiring a large amount of ram/cpu (>= 2g up to 16g) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | New maven properties for maven surefire: | ||||||
|  | test.offheap.size: tunes off heap size for javacpp | ||||||
|  | test.heap.size: tunes heap size of test jvms | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | Auto tuning the number of CPU cores for tests relative to the number of CPUs present | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| ## Consequences | ## Consequences | ||||||
|  | |||||||
| @ -58,6 +58,7 @@ import java.io.File; | |||||||
| import java.io.FileInputStream; | import java.io.FileInputStream; | ||||||
| import java.io.IOException; | import java.io.IOException; | ||||||
| import java.io.OutputStream; | import java.io.OutputStream; | ||||||
|  | import java.nio.Buffer; | ||||||
| import java.nio.ByteBuffer; | import java.nio.ByteBuffer; | ||||||
| import java.nio.ByteOrder; | import java.nio.ByteOrder; | ||||||
| import java.util.*; | import java.util.*; | ||||||
| @ -171,7 +172,8 @@ public class ArrowConverter { | |||||||
|         ByteBuffer direct = ByteBuffer.allocateDirect(fieldVector.getDataBuffer().capacity()); |         ByteBuffer direct = ByteBuffer.allocateDirect(fieldVector.getDataBuffer().capacity()); | ||||||
|         direct.order(ByteOrder.nativeOrder()); |         direct.order(ByteOrder.nativeOrder()); | ||||||
|         fieldVector.getDataBuffer().getBytes(0,direct); |         fieldVector.getDataBuffer().getBytes(0,direct); | ||||||
|         direct.rewind(); |         Buffer buffer1 = (Buffer) direct; | ||||||
|  |         buffer1.rewind(); | ||||||
|         switch(type) { |         switch(type) { | ||||||
|             case Integer: |             case Integer: | ||||||
|                 buffer = Nd4j.createBuffer(direct, DataType.INT,cols,0); |                 buffer = Nd4j.createBuffer(direct, DataType.INT,cols,0); | ||||||
|  | |||||||
| @ -119,6 +119,7 @@ | |||||||
|                 <groupId>org.apache.maven.plugins</groupId> |                 <groupId>org.apache.maven.plugins</groupId> | ||||||
|                 <artifactId>maven-surefire-plugin</artifactId> |                 <artifactId>maven-surefire-plugin</artifactId> | ||||||
|                 <configuration> |                 <configuration> | ||||||
|  |                     <forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/> | ||||||
|                     <classpathDependencyExcludes> |                     <classpathDependencyExcludes> | ||||||
|                         <classpathDependencyExclude>com.google.android:android |                         <classpathDependencyExclude>com.google.android:android | ||||||
|                         </classpathDependencyExclude> |                         </classpathDependencyExclude> | ||||||
|  | |||||||
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							| @ -19,31 +19,26 @@ | |||||||
|  */ |  */ | ||||||
| package org.deeplearning4j.datasets; | package org.deeplearning4j.datasets; | ||||||
| 
 | 
 | ||||||
| import org.apache.commons.io.FileUtils; |  | ||||||
| import org.deeplearning4j.BaseDL4JTest; | import org.deeplearning4j.BaseDL4JTest; | ||||||
| import org.deeplearning4j.datasets.base.MnistFetcher; |  | ||||||
| import org.deeplearning4j.common.resources.DL4JResources; | import org.deeplearning4j.common.resources.DL4JResources; | ||||||
|  | import org.deeplearning4j.datasets.base.MnistFetcher; | ||||||
| import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; | import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; | ||||||
| import org.junit.jupiter.api.*; | import org.junit.jupiter.api.*; | ||||||
| import org.junit.jupiter.api.io.TempDir; | import org.junit.jupiter.api.io.TempDir; | ||||||
| 
 |  | ||||||
| import org.nd4j.common.tests.tags.NativeTag; | import org.nd4j.common.tests.tags.NativeTag; | ||||||
| import org.nd4j.common.tests.tags.TagNames; | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; | import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; | ||||||
| import org.nd4j.linalg.dataset.DataSet; | import org.nd4j.linalg.dataset.DataSet; | ||||||
| import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| import org.nd4j.linalg.indexing.conditions.Conditions; | import org.nd4j.linalg.indexing.conditions.Conditions; | ||||||
|  | 
 | ||||||
| import java.io.File; | import java.io.File; | ||||||
| import java.nio.file.Path; | import java.nio.file.Path; | ||||||
| import java.util.HashSet; | import java.util.HashSet; | ||||||
| import java.util.Set; | import java.util.Set; | ||||||
| import static org.junit.jupiter.api.Assertions.assertEquals; |  | ||||||
| import static org.junit.jupiter.api.Assertions.assertFalse; |  | ||||||
| import static org.junit.jupiter.api.Assertions.assertTrue; |  | ||||||
| 
 | 
 | ||||||
| import org.junit.jupiter.api.extension.ExtendWith; | import static org.junit.jupiter.api.Assertions.*; | ||||||
| 
 | 
 | ||||||
| @DisplayName("Mnist Fetcher Test") | @DisplayName("Mnist Fetcher Test") | ||||||
| @NativeTag | @NativeTag | ||||||
| @ -65,6 +60,9 @@ class MnistFetcherTest extends BaseDL4JTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     @DisplayName("Test Mnist") |     @DisplayName("Test Mnist") | ||||||
|  |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|  |     @Tag(TagNames.FILE_IO) | ||||||
|     void testMnist() throws Exception { |     void testMnist() throws Exception { | ||||||
|         MnistDataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1); |         MnistDataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1); | ||||||
|         int count = 0; |         int count = 0; | ||||||
| @ -91,6 +89,9 @@ class MnistFetcherTest extends BaseDL4JTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     @DisplayName("Test Mnist Data Fetcher") |     @DisplayName("Test Mnist Data Fetcher") | ||||||
|  |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|  |     @Tag(TagNames.FILE_IO) | ||||||
|     void testMnistDataFetcher() throws Exception { |     void testMnistDataFetcher() throws Exception { | ||||||
|         MnistFetcher mnistFetcher = new MnistFetcher(); |         MnistFetcher mnistFetcher = new MnistFetcher(); | ||||||
|         File mnistDir = mnistFetcher.downloadAndUntar(); |         File mnistDir = mnistFetcher.downloadAndUntar(); | ||||||
| @ -99,6 +100,9 @@ class MnistFetcherTest extends BaseDL4JTest { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|  |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|  |     @Tag(TagNames.FILE_IO) | ||||||
|     public void testMnistSubset() throws Exception { |     public void testMnistSubset() throws Exception { | ||||||
|         final int numExamples = 100; |         final int numExamples = 100; | ||||||
|         MnistDataSetIterator iter1 = new MnistDataSetIterator(10, numExamples, false, true, true, 123); |         MnistDataSetIterator iter1 = new MnistDataSetIterator(10, numExamples, false, true, true, 123); | ||||||
| @ -144,6 +148,9 @@ class MnistFetcherTest extends BaseDL4JTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     @DisplayName("Test Subset Repeatability") |     @DisplayName("Test Subset Repeatability") | ||||||
|  |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|  |     @Tag(TagNames.FILE_IO) | ||||||
|     void testSubsetRepeatability() throws Exception { |     void testSubsetRepeatability() throws Exception { | ||||||
|         MnistDataSetIterator it = new MnistDataSetIterator(1, 1, false, false, true, 0); |         MnistDataSetIterator it = new MnistDataSetIterator(1, 1, false, false, true, 0); | ||||||
|         DataSet d1 = it.next(); |         DataSet d1 = it.next(); | ||||||
|  | |||||||
| @ -51,6 +51,7 @@ public class TestEmnistDataSetIterator extends BaseDL4JTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     @Tag(TagNames.LONG_TEST) |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|     public void testEmnistDataSetIterator() throws Exception { |     public void testEmnistDataSetIterator() throws Exception { | ||||||
|         int batchSize = 128; |         int batchSize = 128; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -63,6 +63,8 @@ import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |||||||
| import org.deeplearning4j.util.ModelSerializer; | import org.deeplearning4j.util.ModelSerializer; | ||||||
| import org.junit.jupiter.api.*; | import org.junit.jupiter.api.*; | ||||||
| import org.junit.jupiter.api.io.TempDir; | import org.junit.jupiter.api.io.TempDir; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
| import org.nd4j.common.tests.tags.NativeTag; | import org.nd4j.common.tests.tags.NativeTag; | ||||||
| import org.nd4j.common.tests.tags.TagNames; | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.linalg.activations.Activation; | import org.nd4j.linalg.activations.Activation; | ||||||
| @ -1717,8 +1719,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { | |||||||
|         MultiLayerTest.CheckModelsListener listener = new MultiLayerTest.CheckModelsListener(); |         MultiLayerTest.CheckModelsListener listener = new MultiLayerTest.CheckModelsListener(); | ||||||
|         net.setListeners(listener); |         net.setListeners(listener); | ||||||
| 
 | 
 | ||||||
|         INDArray f = Nd4j.create(1,10); |         INDArray f = Nd4j.create(DataType.DOUBLE,1,10); | ||||||
|         INDArray l = Nd4j.create(1,10); |         INDArray l = Nd4j.create(DataType.DOUBLE,1,10); | ||||||
|         DataSet ds = new DataSet(f,l); |         DataSet ds = new DataSet(f,l); | ||||||
|         MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(f,l); |         MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(f,l); | ||||||
| 
 | 
 | ||||||
| @ -2117,9 +2119,10 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|  |     @Execution(ExecutionMode.SAME_THREAD) | ||||||
|  |     @Tag(TagNames.NEEDS_VERIFY) | ||||||
|  |     @Disabled | ||||||
|     public void testCompGraphInputReuse() { |     public void testCompGraphInputReuse() { | ||||||
|         Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); |  | ||||||
| 
 |  | ||||||
|         int inputSize = 5; |         int inputSize = 5; | ||||||
|         int outputSize = 6; |         int outputSize = 6; | ||||||
|         int layerSize = 3; |         int layerSize = 3; | ||||||
| @ -2134,7 +2137,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { | |||||||
|                 .setOutputs("out") |                 .setOutputs("out") | ||||||
|                 .addLayer("0",new DenseLayer.Builder().nIn(inputSize).nOut(layerSize).build(),"in") |                 .addLayer("0",new DenseLayer.Builder().nIn(inputSize).nOut(layerSize).build(),"in") | ||||||
|                 .addVertex("combine", new MergeVertex(), "0", "0", "0") |                 .addVertex("combine", new MergeVertex(), "0", "0", "0") | ||||||
|                 .addLayer("out",new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(3*layerSize).nOut(outputSize) |                 .addLayer("out",new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(3*layerSize) | ||||||
|  |                         .nOut(outputSize) | ||||||
|                         .activation(Activation.SIGMOID).build(),"combine") |                         .activation(Activation.SIGMOID).build(),"combine") | ||||||
|                 .build(); |                 .build(); | ||||||
| 
 | 
 | ||||||
| @ -2143,8 +2147,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|         int dataSize = 11; |         int dataSize = 11; | ||||||
|         INDArray features = Nd4j.rand(new int[] {dataSize, inputSize}); |         INDArray features = Nd4j.rand(DataType.DOUBLE,new int[] {dataSize, inputSize}); | ||||||
|         INDArray labels = Nd4j.rand(new int[] {dataSize, outputSize}); |         INDArray labels = Nd4j.rand(DataType.DOUBLE,new int[] {dataSize, outputSize}); | ||||||
| 
 | 
 | ||||||
|         boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{features}) |         boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{features}) | ||||||
|                 .labels(new INDArray[]{labels})); |                 .labels(new INDArray[]{labels})); | ||||||
|  | |||||||
| @ -23,8 +23,11 @@ import org.deeplearning4j.BaseDL4JTest; | |||||||
| import org.deeplearning4j.nn.conf.distribution.*; | import org.deeplearning4j.nn.conf.distribution.*; | ||||||
| import org.deeplearning4j.nn.conf.serde.JsonMappers; | import org.deeplearning4j.nn.conf.serde.JsonMappers; | ||||||
| import org.junit.jupiter.api.*; | import org.junit.jupiter.api.*; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
| import org.nd4j.common.tests.tags.NativeTag; | import org.nd4j.common.tests.tags.NativeTag; | ||||||
| import org.nd4j.common.tests.tags.TagNames; | import org.nd4j.common.tests.tags.TagNames; | ||||||
|  | import org.nd4j.linalg.api.buffer.DataType; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.rng.Random; | import org.nd4j.linalg.api.rng.Random; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| @ -69,14 +72,19 @@ class LegacyWeightInitTest extends BaseDL4JTest { | |||||||
|         final long[] shape = { 5, 5 }; |         final long[] shape = { 5, 5 }; | ||||||
|         final long fanIn = shape[0]; |         final long fanIn = shape[0]; | ||||||
|         final long fanOut = shape[1]; |         final long fanOut = shape[1]; | ||||||
|         final INDArray inLegacy = Nd4j.create(fanIn * fanOut); |         final INDArray inLegacy = Nd4j.create(DataType.DOUBLE,fanIn * fanOut); | ||||||
|         final INDArray inTest = inLegacy.dup(); |         final INDArray inTest = inLegacy.dup(); | ||||||
|         for (WeightInit legacyWi : WeightInit.values()) { |         for (WeightInit legacyWi : WeightInit.values()) { | ||||||
|             if (legacyWi != WeightInit.DISTRIBUTION) { |             if (legacyWi != WeightInit.DISTRIBUTION) { | ||||||
|                 Nd4j.getRandom().setSeed(SEED); |                 Nd4j.getRandom().setSeed(SEED); | ||||||
|                 final INDArray expected = WeightInitUtil.initWeights(fanIn, fanOut, shape, legacyWi, null, inLegacy); |                 final INDArray expected = WeightInitUtil. | ||||||
|  |                         initWeights(fanIn, fanOut, shape, legacyWi, null, inLegacy) | ||||||
|  |                         .castTo(DataType.DOUBLE); | ||||||
|                 Nd4j.getRandom().setSeed(SEED); |                 Nd4j.getRandom().setSeed(SEED); | ||||||
|                 final INDArray actual = legacyWi.getWeightInitFunction().init(fanIn, fanOut, shape, WeightInitUtil.DEFAULT_WEIGHT_INIT_ORDER, inTest); |                 final INDArray actual = legacyWi.getWeightInitFunction() | ||||||
|  |                         .init(fanIn, fanOut, shape, | ||||||
|  |                                 WeightInitUtil.DEFAULT_WEIGHT_INIT_ORDER, inTest) | ||||||
|  |                         .castTo(DataType.DOUBLE); | ||||||
|                 assertArrayEquals(shape, actual.shape(),"Incorrect shape for " + legacyWi + "!"); |                 assertArrayEquals(shape, actual.shape(),"Incorrect shape for " + legacyWi + "!"); | ||||||
|                 assertEquals( expected, actual,"Incorrect weight initialization for " + legacyWi + "!"); |                 assertEquals( expected, actual,"Incorrect weight initialization for " + legacyWi + "!"); | ||||||
|             } |             } | ||||||
| @ -88,17 +96,24 @@ class LegacyWeightInitTest extends BaseDL4JTest { | |||||||
|      */ |      */ | ||||||
|     @Test |     @Test | ||||||
|     @DisplayName("Init Params From Distribution") |     @DisplayName("Init Params From Distribution") | ||||||
|  |     @Execution(ExecutionMode.SAME_THREAD) | ||||||
|  |     @Disabled(TagNames.NEEDS_VERIFY) | ||||||
|     void initParamsFromDistribution() { |     void initParamsFromDistribution() { | ||||||
|         // To make identity happy |         // To make identity happy | ||||||
|         final long[] shape = { 3, 7 }; |         final long[] shape = { 3, 7 }; | ||||||
|         final long fanIn = shape[0]; |         final long fanIn = shape[0]; | ||||||
|         final long fanOut = shape[1]; |         final long fanOut = shape[1]; | ||||||
|         final INDArray inLegacy = Nd4j.create(fanIn * fanOut); |         final INDArray inLegacy = Nd4j.create(DataType.DOUBLE,fanIn * fanOut); | ||||||
|         final INDArray inTest = inLegacy.dup(); |         final INDArray inTest = inLegacy.dup(); | ||||||
|         for (Distribution dist : distributions) { |         for (Distribution dist : distributions) { | ||||||
|             Nd4j.getRandom().setSeed(SEED); |             Nd4j.getRandom().setSeed(SEED); | ||||||
|             final INDArray expected = WeightInitUtil.initWeights(fanIn, fanOut, shape, WeightInit.DISTRIBUTION, Distributions.createDistribution(dist), inLegacy); |             final INDArray expected = WeightInitUtil | ||||||
|             final INDArray actual = new WeightInitDistribution(dist).init(fanIn, fanOut, shape, WeightInitUtil.DEFAULT_WEIGHT_INIT_ORDER, inTest); |                     .initWeights(fanIn, fanOut, shape, WeightInit.DISTRIBUTION, | ||||||
|  |                             Distributions.createDistribution(dist), inLegacy) | ||||||
|  |                     .castTo(DataType.DOUBLE); | ||||||
|  |             final INDArray actual = new WeightInitDistribution(dist) | ||||||
|  |                     .init(fanIn, fanOut, shape, WeightInitUtil.DEFAULT_WEIGHT_INIT_ORDER, | ||||||
|  |                             inTest).castTo(DataType.DOUBLE); | ||||||
|             assertArrayEquals(shape, actual.shape(),"Incorrect shape for " + dist.getClass().getSimpleName() + "!"); |             assertArrayEquals(shape, actual.shape(),"Incorrect shape for " + dist.getClass().getSimpleName() + "!"); | ||||||
|             assertEquals( expected, actual,"Incorrect weight initialization for " + dist.getClass().getSimpleName() + "!"); |             assertEquals( expected, actual,"Incorrect weight initialization for " + dist.getClass().getSimpleName() + "!"); | ||||||
|         } |         } | ||||||
|  | |||||||
| @ -34,6 +34,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |||||||
| import org.deeplearning4j.nn.weights.WeightInit; | import org.deeplearning4j.nn.weights.WeightInit; | ||||||
| import org.junit.jupiter.api.Tag; | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
| import org.nd4j.common.tests.tags.NativeTag; | import org.nd4j.common.tests.tags.NativeTag; | ||||||
| import org.nd4j.common.tests.tags.TagNames; | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.linalg.activations.Activation; | import org.nd4j.linalg.activations.Activation; | ||||||
| @ -56,42 +58,41 @@ public class RandomTests extends BaseDL4JTest { | |||||||
|      * |      * | ||||||
|      * @throws Exception |      * @throws Exception | ||||||
|      */ |      */ | ||||||
|     @Test |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|  |     @Execution(ExecutionMode.SAME_THREAD) | ||||||
|     public void testModelInitialParamsEquality1() throws Exception { |     public void testModelInitialParamsEquality1() throws Exception { | ||||||
|         final List<Model> models = new CopyOnWriteArrayList<>(); |         final List<Model> models = new CopyOnWriteArrayList<>(); | ||||||
| 
 | 
 | ||||||
|         for (int i = 0; i < 4; i++) { |         for (int i = 0; i < 4; i++) { | ||||||
|             Thread thread = new Thread(new Runnable() { |             Thread thread = new Thread(() -> { | ||||||
|                 @Override |                 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(119) // Training iterations as above | ||||||
|                 public void run() { |                         .l2(0.0005) | ||||||
|                     MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(119) // Training iterations as above |                         //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75) | ||||||
|                                     .l2(0.0005) |                         .weightInit(WeightInit.XAVIER) | ||||||
|                                     //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75) |                         .updater(new Nesterovs(0.01, 0.9)) | ||||||
|                                     .weightInit(WeightInit.XAVIER) |                         .trainingWorkspaceMode(WorkspaceMode.ENABLED).list() | ||||||
|                                     .updater(new Nesterovs(0.01, 0.9)) |                         .layer(0, new ConvolutionLayer.Builder(5, 5) | ||||||
|                                     .trainingWorkspaceMode(WorkspaceMode.ENABLED).list() |                                 //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied | ||||||
|                                     .layer(0, new ConvolutionLayer.Builder(5, 5) |                                 .nIn(1).stride(1, 1).nOut(20).activation(Activation.IDENTITY) | ||||||
|                                                     //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied |                                 .build()) | ||||||
|                                                     .nIn(1).stride(1, 1).nOut(20).activation(Activation.IDENTITY) |                         .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) | ||||||
|                                                     .build()) |                                 .kernelSize(2, 2).stride(2, 2).build()) | ||||||
|                                     .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) |                         .layer(2, new ConvolutionLayer.Builder(5, 5) | ||||||
|                                                     .kernelSize(2, 2).stride(2, 2).build()) |                                 //Note that nIn need not be specified in later layers | ||||||
|                                     .layer(2, new ConvolutionLayer.Builder(5, 5) |                                 .stride(1, 1).nOut(50).activation(Activation.IDENTITY).build()) | ||||||
|                                                     //Note that nIn need not be specified in later layers |                         .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) | ||||||
|                                                     .stride(1, 1).nOut(50).activation(Activation.IDENTITY).build()) |                                 .kernelSize(2, 2).stride(2, 2).build()) | ||||||
|                                     .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) |                         .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()) | ||||||
|                                                     .kernelSize(2, 2).stride(2, 2).build()) |                         .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) | ||||||
|                                     .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()) |                                 .nOut(10).activation(Activation.SOFTMAX).build()) | ||||||
|                                     .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) |                         .setInputType(InputType.convolutionalFlat(28, 28, 1)) //See note below | ||||||
|                                                     .nOut(10).activation(Activation.SOFTMAX).build()) |                         .build(); | ||||||
|                                     .setInputType(InputType.convolutionalFlat(28, 28, 1)) //See note below |  | ||||||
|                                     .build(); |  | ||||||
| 
 | 
 | ||||||
|                     MultiLayerNetwork network = new MultiLayerNetwork(conf); |                 MultiLayerNetwork network = new MultiLayerNetwork(conf); | ||||||
|                     network.init(); |                 network.init(); | ||||||
| 
 | 
 | ||||||
|                     models.add(network); |                 models.add(network); | ||||||
|                 } |  | ||||||
|             }); |             }); | ||||||
| 
 | 
 | ||||||
|             thread.start(); |             thread.start(); | ||||||
| @ -111,12 +112,12 @@ public class RandomTests extends BaseDL4JTest { | |||||||
|         Nd4j.getRandom().setSeed(12345); |         Nd4j.getRandom().setSeed(12345); | ||||||
| 
 | 
 | ||||||
|         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).activation(Activation.TANH) |         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).activation(Activation.TANH) | ||||||
|                         .weightInit(WeightInit.XAVIER).list() |                 .weightInit(WeightInit.XAVIER).list() | ||||||
|                         .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) |                 .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) | ||||||
|                         .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(2, |                 .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(2, | ||||||
|                                         new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) |                         new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) | ||||||
|                                                         .activation(Activation.SOFTMAX).nIn(10).nOut(10).build()) |                                 .activation(Activation.SOFTMAX).nIn(10).nOut(10).build()) | ||||||
|                         .build(); |                 .build(); | ||||||
| 
 | 
 | ||||||
|         String json = conf.toJson(); |         String json = conf.toJson(); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -47,8 +47,9 @@ | |||||||
|                     <configuration> |                     <configuration> | ||||||
|                         <forkCount>${cpu.core.count}</forkCount> |                         <forkCount>${cpu.core.count}</forkCount> | ||||||
|                         <reuseForks>false</reuseForks> |                         <reuseForks>false</reuseForks> | ||||||
|  |                         <forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/> | ||||||
|                         <argLine>-Ddtype=float -Dfile.encoding=UTF-8 |                         <argLine>-Ddtype=float -Dfile.encoding=UTF-8 | ||||||
|                             -Dtest.solr.allowed.securerandom=NativePRNG |                             -Dtest.solr.allowed.securerandom=NativePRNG -Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size} | ||||||
|                         </argLine> |                         </argLine> | ||||||
|                         <includes> |                         <includes> | ||||||
|                             <!-- Default setting only runs tests that start/end with "Test" --> |                             <!-- Default setting only runs tests that start/end with "Test" --> | ||||||
|  | |||||||
| @ -48,6 +48,8 @@ import org.junit.jupiter.api.extension.ExtendWith; | |||||||
| @DisplayName("Tuple Stream Data Set Iterator Test") | @DisplayName("Tuple Stream Data Set Iterator Test") | ||||||
| @Tag(TagNames.SOLR) | @Tag(TagNames.SOLR) | ||||||
| @Tag(TagNames.DIST_SYSTEMS) | @Tag(TagNames.DIST_SYSTEMS) | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
|  | @Tag(TagNames.LONG_TEST) | ||||||
| class TupleStreamDataSetIteratorTest extends SolrCloudTestCase { | class TupleStreamDataSetIteratorTest extends SolrCloudTestCase { | ||||||
| 
 | 
 | ||||||
|     static { |     static { | ||||||
|  | |||||||
| @ -41,7 +41,8 @@ | |||||||
|                     <groupId>org.apache.maven.plugins</groupId> |                     <groupId>org.apache.maven.plugins</groupId> | ||||||
|                     <artifactId>maven-surefire-plugin</artifactId> |                     <artifactId>maven-surefire-plugin</artifactId> | ||||||
|                     <configuration> |                     <configuration> | ||||||
|                         <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g |                         <forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/> | ||||||
|  |                         <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx${test.heap.size} | ||||||
|                             -Dtest.solr.allowed.securerandom=NativePRNG |                             -Dtest.solr.allowed.securerandom=NativePRNG | ||||||
|                         </argLine> |                         </argLine> | ||||||
|                         <includes> |                         <includes> | ||||||
|  | |||||||
| @ -76,6 +76,8 @@ import static org.junit.jupiter.api.Assertions.*; | |||||||
| 
 | 
 | ||||||
| @Tag(TagNames.FILE_IO) | @Tag(TagNames.FILE_IO) | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
|  | @Tag(TagNames.LONG_TEST) | ||||||
| public class SequenceVectorsTest extends BaseDL4JTest { | public class SequenceVectorsTest extends BaseDL4JTest { | ||||||
| 
 | 
 | ||||||
|     protected static final Logger logger = LoggerFactory.getLogger(SequenceVectorsTest.class); |     protected static final Logger logger = LoggerFactory.getLogger(SequenceVectorsTest.class); | ||||||
|  | |||||||
| @ -424,12 +424,7 @@ public class GradientCheckUtil { | |||||||
|             throw new IllegalArgumentException( |             throw new IllegalArgumentException( | ||||||
|                             "Invalid labels arrays: expect " + c.net.getNumOutputArrays() + " outputs"); |                             "Invalid labels arrays: expect " + c.net.getNumOutputArrays() + " outputs"); | ||||||
| 
 | 
 | ||||||
|         DataType dataType = DataTypeUtil.getDtypeFromContext(); |        | ||||||
|         if (dataType != DataType.DOUBLE) { |  | ||||||
|             throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision (" |  | ||||||
|                             + "is: " + dataType + "). Double precision must be used for gradient checks. Set " |  | ||||||
|                             + "DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil"); |  | ||||||
|         } |  | ||||||
| 
 | 
 | ||||||
|         DataType netDataType = c.net.getConfiguration().getDataType(); |         DataType netDataType = c.net.getConfiguration().getDataType(); | ||||||
|         if (netDataType != DataType.DOUBLE) { |         if (netDataType != DataType.DOUBLE) { | ||||||
|  | |||||||
| @ -21,6 +21,8 @@ | |||||||
| package org.deeplearning4j.spark.models.sequencevectors; | package org.deeplearning4j.spark.models.sequencevectors; | ||||||
| 
 | 
 | ||||||
| import com.sun.jna.Platform; | import com.sun.jna.Platform; | ||||||
|  | import lombok.SneakyThrows; | ||||||
|  | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.apache.spark.SparkConf; | import org.apache.spark.SparkConf; | ||||||
| import org.apache.spark.api.java.JavaRDD; | import org.apache.spark.api.java.JavaRDD; | ||||||
| import org.apache.spark.api.java.JavaSparkContext; | import org.apache.spark.api.java.JavaSparkContext; | ||||||
| @ -35,9 +37,12 @@ import org.deeplearning4j.spark.models.word2vec.SparkWord2VecTest; | |||||||
| import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; | import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; | ||||||
| import org.junit.jupiter.api.*; | import org.junit.jupiter.api.*; | ||||||
| import org.nd4j.common.primitives.Counter; | import org.nd4j.common.primitives.Counter; | ||||||
|  | import org.nd4j.common.resources.Downloader; | ||||||
| import org.nd4j.common.tests.tags.NativeTag; | import org.nd4j.common.tests.tags.NativeTag; | ||||||
| import org.nd4j.common.tests.tags.TagNames; | import org.nd4j.common.tests.tags.TagNames; | ||||||
| 
 | 
 | ||||||
|  | import java.io.File; | ||||||
|  | import java.net.URI; | ||||||
| import java.util.ArrayList; | import java.util.ArrayList; | ||||||
| import java.util.List; | import java.util.List; | ||||||
| 
 | 
 | ||||||
| @ -47,6 +52,7 @@ import static org.junit.jupiter.api.Assertions.assertNotEquals; | |||||||
| @Tag(TagNames.SPARK) | @Tag(TagNames.SPARK) | ||||||
| @Tag(TagNames.DIST_SYSTEMS) | @Tag(TagNames.DIST_SYSTEMS) | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Slf4j | ||||||
| public class SparkSequenceVectorsTest extends BaseDL4JTest { | public class SparkSequenceVectorsTest extends BaseDL4JTest { | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
| @ -57,6 +63,27 @@ public class SparkSequenceVectorsTest extends BaseDL4JTest { | |||||||
|     protected static List<Sequence<VocabWord>> sequencesCyclic; |     protected static List<Sequence<VocabWord>> sequencesCyclic; | ||||||
|     private JavaSparkContext sc; |     private JavaSparkContext sc; | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  |     @BeforeAll | ||||||
|  |     @SneakyThrows | ||||||
|  |     public static void beforeAll() { | ||||||
|  |         if(Platform.isWindows()) { | ||||||
|  |             File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp"); | ||||||
|  |             File binDir = new File(hadoopHome,"bin"); | ||||||
|  |             if(!binDir.exists()) | ||||||
|  |                 binDir.mkdirs(); | ||||||
|  |             File outputFile = new File(binDir,"winutils.exe"); | ||||||
|  |             if(!outputFile.exists()) { | ||||||
|  |                 log.info("Fixing spark for windows"); | ||||||
|  |                 Downloader.download("winutils.exe", | ||||||
|  |                         URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(), | ||||||
|  |                         outputFile,"db24b404d2331a1bec7443336a5171f1",3); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @BeforeEach |     @BeforeEach | ||||||
|     public void setUp() throws Exception { |     public void setUp() throws Exception { | ||||||
|         if (sequencesCyclic == null) { |         if (sequencesCyclic == null) { | ||||||
|  | |||||||
| @ -20,6 +20,9 @@ | |||||||
| 
 | 
 | ||||||
| package org.deeplearning4j.spark.models.word2vec; | package org.deeplearning4j.spark.models.word2vec; | ||||||
| 
 | 
 | ||||||
|  | import com.sun.jna.Platform; | ||||||
|  | import lombok.SneakyThrows; | ||||||
|  | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.apache.spark.SparkConf; | import org.apache.spark.SparkConf; | ||||||
| import org.apache.spark.api.java.JavaRDD; | import org.apache.spark.api.java.JavaRDD; | ||||||
| import org.apache.spark.api.java.JavaSparkContext; | import org.apache.spark.api.java.JavaSparkContext; | ||||||
| @ -35,11 +38,14 @@ import org.deeplearning4j.spark.models.sequencevectors.export.SparkModelExporter | |||||||
| import org.deeplearning4j.spark.models.sequencevectors.learning.elements.SparkSkipGram; | import org.deeplearning4j.spark.models.sequencevectors.learning.elements.SparkSkipGram; | ||||||
| import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; | import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; | ||||||
| import org.junit.jupiter.api.*; | import org.junit.jupiter.api.*; | ||||||
|  | import org.nd4j.common.resources.Downloader; | ||||||
| import org.nd4j.common.tests.tags.NativeTag; | import org.nd4j.common.tests.tags.NativeTag; | ||||||
| import org.nd4j.common.tests.tags.TagNames; | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; | import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; | ||||||
| 
 | 
 | ||||||
|  | import java.io.File; | ||||||
| import java.io.Serializable; | import java.io.Serializable; | ||||||
|  | import java.net.URI; | ||||||
| import java.util.ArrayList; | import java.util.ArrayList; | ||||||
| import java.util.List; | import java.util.List; | ||||||
| 
 | 
 | ||||||
| @ -48,6 +54,7 @@ import static org.junit.jupiter.api.Assertions.*; | |||||||
| @Tag(TagNames.SPARK) | @Tag(TagNames.SPARK) | ||||||
| @Tag(TagNames.DIST_SYSTEMS) | @Tag(TagNames.DIST_SYSTEMS) | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Slf4j | ||||||
| public class SparkWord2VecTest extends BaseDL4JTest { | public class SparkWord2VecTest extends BaseDL4JTest { | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
| @ -58,6 +65,27 @@ public class SparkWord2VecTest extends BaseDL4JTest { | |||||||
|     private static List<String> sentences; |     private static List<String> sentences; | ||||||
|     private JavaSparkContext sc; |     private JavaSparkContext sc; | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  |     @BeforeAll | ||||||
|  |     @SneakyThrows | ||||||
|  |     public static void beforeAll() { | ||||||
|  |         if(Platform.isWindows()) { | ||||||
|  |             File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp"); | ||||||
|  |             File binDir = new File(hadoopHome,"bin"); | ||||||
|  |             if(!binDir.exists()) | ||||||
|  |                 binDir.mkdirs(); | ||||||
|  |             File outputFile = new File(binDir,"winutils.exe"); | ||||||
|  |             if(!outputFile.exists()) { | ||||||
|  |                 log.info("Fixing spark for windows"); | ||||||
|  |                 Downloader.download("winutils.exe", | ||||||
|  |                         URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(), | ||||||
|  |                         outputFile,"db24b404d2331a1bec7443336a5171f1",3); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @BeforeEach |     @BeforeEach | ||||||
|     public void setUp() throws Exception { |     public void setUp() throws Exception { | ||||||
|         if (sentences == null) { |         if (sentences == null) { | ||||||
|  | |||||||
| @ -21,11 +21,15 @@ | |||||||
| package org.deeplearning4j.spark.models.embeddings.word2vec; | package org.deeplearning4j.spark.models.embeddings.word2vec; | ||||||
| 
 | 
 | ||||||
| import com.sun.jna.Platform; | import com.sun.jna.Platform; | ||||||
|  | import lombok.SneakyThrows; | ||||||
|  | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.apache.spark.SparkConf; | import org.apache.spark.SparkConf; | ||||||
| import org.apache.spark.api.java.JavaRDD; | import org.apache.spark.api.java.JavaRDD; | ||||||
| import org.apache.spark.api.java.JavaSparkContext; | import org.apache.spark.api.java.JavaSparkContext; | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | import org.deeplearning4j.common.resources.DL4JResources; | ||||||
|  | import org.junit.jupiter.api.BeforeAll; | ||||||
| import org.junit.jupiter.api.Tag; | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.io.TempDir; | import org.junit.jupiter.api.io.TempDir; | ||||||
| import org.nd4j.common.io.ClassPathResource; | import org.nd4j.common.io.ClassPathResource; | ||||||
| @ -41,11 +45,14 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFac | |||||||
| import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; | import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; | ||||||
| import org.junit.jupiter.api.Disabled; | import org.junit.jupiter.api.Disabled; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.nd4j.common.resources.Downloader; | ||||||
|  | import org.nd4j.common.resources.strumpf.StrumpfResolver; | ||||||
| import org.nd4j.common.tests.tags.NativeTag; | import org.nd4j.common.tests.tags.NativeTag; | ||||||
| import org.nd4j.common.tests.tags.TagNames; | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| 
 | 
 | ||||||
| import java.io.File; | import java.io.File; | ||||||
|  | import java.net.URI; | ||||||
| import java.nio.file.Files; | import java.nio.file.Files; | ||||||
| import java.nio.file.Path; | import java.nio.file.Path; | ||||||
| import java.util.Arrays; | import java.util.Arrays; | ||||||
| @ -53,21 +60,37 @@ import java.util.Collection; | |||||||
| 
 | 
 | ||||||
| import static org.junit.jupiter.api.Assertions.*; | import static org.junit.jupiter.api.Assertions.*; | ||||||
| 
 | 
 | ||||||
| @Disabled |  | ||||||
| @Tag(TagNames.FILE_IO) | @Tag(TagNames.FILE_IO) | ||||||
| @Tag(TagNames.SPARK) | @Tag(TagNames.SPARK) | ||||||
| @Tag(TagNames.DIST_SYSTEMS) | @Tag(TagNames.DIST_SYSTEMS) | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Slf4j | ||||||
|  | @Tag(TagNames.LONG_TEST) | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
| public class Word2VecTest { | public class Word2VecTest { | ||||||
|  |     @BeforeAll | ||||||
|  |     @SneakyThrows | ||||||
|  |     public static void beforeAll() { | ||||||
|  |         if(Platform.isWindows()) { | ||||||
|  |             File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp"); | ||||||
|  |             File binDir = new File(hadoopHome,"bin"); | ||||||
|  |             if(!binDir.exists()) | ||||||
|  |                 binDir.mkdirs(); | ||||||
|  |             File outputFile = new File(binDir,"winutils.exe"); | ||||||
|  |             if(!outputFile.exists()) { | ||||||
|  |                 log.info("Fixing spark for windows"); | ||||||
|  |                 Downloader.download("winutils.exe", | ||||||
|  |                         URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(), | ||||||
|  |                         outputFile,"db24b404d2331a1bec7443336a5171f1",3); | ||||||
|  |             } | ||||||
| 
 | 
 | ||||||
|  |             System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testConcepts(@TempDir Path testDir) throws Exception { |     public void testConcepts(@TempDir Path testDir) throws Exception { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         // These are all default values for word2vec |         // These are all default values for word2vec | ||||||
|         SparkConf sparkConf = new SparkConf().setMaster("local[8]") |         SparkConf sparkConf = new SparkConf().setMaster("local[8]") | ||||||
|                 .set("spark.driver.host", "localhost") |                 .set("spark.driver.host", "localhost") | ||||||
|  | |||||||
| @ -20,21 +20,50 @@ | |||||||
| 
 | 
 | ||||||
| package org.deeplearning4j.spark.text; | package org.deeplearning4j.spark.text; | ||||||
| 
 | 
 | ||||||
|  | import com.sun.jna.Platform; | ||||||
|  | import lombok.SneakyThrows; | ||||||
|  | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.apache.spark.SparkConf; | import org.apache.spark.SparkConf; | ||||||
| import org.apache.spark.api.java.JavaSparkContext; | import org.apache.spark.api.java.JavaSparkContext; | ||||||
| import org.deeplearning4j.BaseDL4JTest; | import org.deeplearning4j.BaseDL4JTest; | ||||||
| import org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables; | import org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables; | ||||||
| import org.junit.jupiter.api.AfterEach; | import org.junit.jupiter.api.AfterEach; | ||||||
|  | import org.junit.jupiter.api.BeforeAll; | ||||||
| import org.junit.jupiter.api.BeforeEach; | import org.junit.jupiter.api.BeforeEach; | ||||||
|  | import org.nd4j.common.resources.Downloader; | ||||||
| 
 | 
 | ||||||
|  | import java.io.File; | ||||||
| import java.io.Serializable; | import java.io.Serializable; | ||||||
| import java.lang.reflect.Field; | import java.lang.reflect.Field; | ||||||
|  | import java.net.URI; | ||||||
| import java.util.Collections; | import java.util.Collections; | ||||||
| import java.util.Map; | import java.util.Map; | ||||||
| 
 | @Slf4j | ||||||
| public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable { | public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable { | ||||||
|     protected transient JavaSparkContext sc; |     protected transient JavaSparkContext sc; | ||||||
| 
 | 
 | ||||||
|  |     @BeforeAll | ||||||
|  |     @SneakyThrows | ||||||
|  |     public static void beforeAll() { | ||||||
|  |         if(Platform.isWindows()) { | ||||||
|  |             File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp"); | ||||||
|  |             File binDir = new File(hadoopHome,"bin"); | ||||||
|  |             if(!binDir.exists()) | ||||||
|  |                 binDir.mkdirs(); | ||||||
|  |             File outputFile = new File(binDir,"winutils.exe"); | ||||||
|  |             if(!outputFile.exists()) { | ||||||
|  |                 log.info("Fixing spark for windows"); | ||||||
|  |                 Downloader.download("winutils.exe", | ||||||
|  |                         URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(), | ||||||
|  |                         outputFile,"db24b404d2331a1bec7443336a5171f1",3); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|     @Override |     @Override | ||||||
|     public long getTimeoutMilliseconds() { |     public long getTimeoutMilliseconds() { | ||||||
|         return 120000L; |         return 120000L; | ||||||
|  | |||||||
| @ -21,6 +21,7 @@ | |||||||
| package org.deeplearning4j.spark.text; | package org.deeplearning4j.spark.text; | ||||||
| 
 | 
 | ||||||
| import com.sun.jna.Platform; | import com.sun.jna.Platform; | ||||||
|  | import lombok.SneakyThrows; | ||||||
| import org.apache.spark.SparkConf; | import org.apache.spark.SparkConf; | ||||||
| import org.apache.spark.api.java.JavaPairRDD; | import org.apache.spark.api.java.JavaPairRDD; | ||||||
| import org.apache.spark.api.java.JavaRDD; | import org.apache.spark.api.java.JavaRDD; | ||||||
| @ -35,10 +36,8 @@ import org.deeplearning4j.spark.models.embeddings.word2vec.Word2Vec; | |||||||
| import org.deeplearning4j.spark.text.functions.CountCumSum; | import org.deeplearning4j.spark.text.functions.CountCumSum; | ||||||
| import org.deeplearning4j.spark.text.functions.TextPipeline; | import org.deeplearning4j.spark.text.functions.TextPipeline; | ||||||
| import org.deeplearning4j.text.stopwords.StopWords; | import org.deeplearning4j.text.stopwords.StopWords; | ||||||
| import org.junit.jupiter.api.BeforeEach; | import org.junit.jupiter.api.*; | ||||||
| import org.junit.jupiter.api.Disabled; | import org.nd4j.common.resources.Downloader; | ||||||
| import org.junit.jupiter.api.Tag; |  | ||||||
| import org.junit.jupiter.api.Test; |  | ||||||
| import org.nd4j.common.tests.tags.NativeTag; | import org.nd4j.common.tests.tags.NativeTag; | ||||||
| import org.nd4j.common.tests.tags.TagNames; | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| @ -48,6 +47,8 @@ import org.slf4j.Logger; | |||||||
| import org.slf4j.LoggerFactory; | import org.slf4j.LoggerFactory; | ||||||
| import scala.Tuple2; | import scala.Tuple2; | ||||||
| 
 | 
 | ||||||
|  | import java.io.File; | ||||||
|  | import java.net.URI; | ||||||
| import java.util.*; | import java.util.*; | ||||||
| import java.util.concurrent.atomic.AtomicLong; | import java.util.concurrent.atomic.AtomicLong; | ||||||
| 
 | 
 | ||||||
| @ -74,6 +75,26 @@ public class TextPipelineTest extends BaseSparkTest { | |||||||
|         return sc.parallelize(sentenceList, 2); |         return sc.parallelize(sentenceList, 2); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     @BeforeAll | ||||||
|  |     @SneakyThrows | ||||||
|  |     public static void beforeAll() { | ||||||
|  |         if(Platform.isWindows()) { | ||||||
|  |             File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp"); | ||||||
|  |             File binDir = new File(hadoopHome,"bin"); | ||||||
|  |             if(!binDir.exists()) | ||||||
|  |                 binDir.mkdirs(); | ||||||
|  |             File outputFile = new File(binDir,"winutils.exe"); | ||||||
|  |             if(!outputFile.exists()) { | ||||||
|  |                 log.info("Fixing spark for windows"); | ||||||
|  |                 Downloader.download("winutils.exe", | ||||||
|  |                         URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(), | ||||||
|  |                         outputFile,"db24b404d2331a1bec7443336a5171f1",3); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @BeforeEach |     @BeforeEach | ||||||
|     public void before() throws Exception { |     public void before() throws Exception { | ||||||
|         conf = new SparkConf().setMaster("local[4]").setAppName("sparktest").set("spark.driver.host", "localhost"); |         conf = new SparkConf().setMaster("local[4]").setAppName("sparktest").set("spark.driver.host", "localhost"); | ||||||
| @ -102,10 +123,6 @@ public class TextPipelineTest extends BaseSparkTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testTokenizer() throws Exception { |     public void testTokenizer() throws Exception { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         JavaSparkContext sc = getContext(); |         JavaSparkContext sc = getContext(); | ||||||
|         JavaRDD<String> corpusRDD = getCorpusRDD(sc); |         JavaRDD<String> corpusRDD = getCorpusRDD(sc); | ||||||
|         Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap()); |         Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap()); | ||||||
|  | |||||||
| @ -20,6 +20,9 @@ | |||||||
| 
 | 
 | ||||||
| package org.deeplearning4j.spark.parameterserver; | package org.deeplearning4j.spark.parameterserver; | ||||||
| 
 | 
 | ||||||
|  | import com.sun.jna.Platform; | ||||||
|  | import lombok.SneakyThrows; | ||||||
|  | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.apache.spark.SparkConf; | import org.apache.spark.SparkConf; | ||||||
| import org.apache.spark.api.java.JavaRDD; | import org.apache.spark.api.java.JavaRDD; | ||||||
| import org.apache.spark.api.java.JavaSparkContext; | import org.apache.spark.api.java.JavaSparkContext; | ||||||
| @ -29,7 +32,9 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |||||||
| import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; | import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; | ||||||
| import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; | import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; | ||||||
| import org.junit.jupiter.api.AfterEach; | import org.junit.jupiter.api.AfterEach; | ||||||
|  | import org.junit.jupiter.api.BeforeAll; | ||||||
| import org.junit.jupiter.api.BeforeEach; | import org.junit.jupiter.api.BeforeEach; | ||||||
|  | import org.nd4j.common.resources.Downloader; | ||||||
| import org.nd4j.linalg.activations.Activation; | import org.nd4j.linalg.activations.Activation; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.dataset.DataSet; | import org.nd4j.linalg.dataset.DataSet; | ||||||
| @ -37,12 +42,14 @@ import org.nd4j.linalg.factory.Nd4j; | |||||||
| import org.nd4j.linalg.learning.config.Nesterovs; | import org.nd4j.linalg.learning.config.Nesterovs; | ||||||
| import org.nd4j.linalg.lossfunctions.LossFunctions; | import org.nd4j.linalg.lossfunctions.LossFunctions; | ||||||
| 
 | 
 | ||||||
|  | import java.io.File; | ||||||
| import java.io.Serializable; | import java.io.Serializable; | ||||||
|  | import java.net.URI; | ||||||
| import java.util.ArrayList; | import java.util.ArrayList; | ||||||
| import java.util.List; | import java.util.List; | ||||||
| import java.util.Random; | import java.util.Random; | ||||||
| 
 | 
 | ||||||
| 
 | @Slf4j | ||||||
| public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable { | public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable { | ||||||
|     protected transient JavaSparkContext sc; |     protected transient JavaSparkContext sc; | ||||||
|     protected transient INDArray labels; |     protected transient INDArray labels; | ||||||
| @ -60,6 +67,27 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable | |||||||
|         return 120000L; |         return 120000L; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  |     @BeforeAll | ||||||
|  |     @SneakyThrows | ||||||
|  |     public static void beforeAll() { | ||||||
|  |         if(Platform.isWindows()) { | ||||||
|  |             File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp"); | ||||||
|  |             File binDir = new File(hadoopHome,"bin"); | ||||||
|  |             if(!binDir.exists()) | ||||||
|  |                 binDir.mkdirs(); | ||||||
|  |             File outputFile = new File(binDir,"winutils.exe"); | ||||||
|  |             if(!outputFile.exists()) { | ||||||
|  |                 log.info("Fixing spark for windows"); | ||||||
|  |                 Downloader.download("winutils.exe", | ||||||
|  |                         URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(), | ||||||
|  |                         outputFile,"db24b404d2331a1bec7443336a5171f1",3); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @BeforeEach |     @BeforeEach | ||||||
|     public void before() { |     public void before() { | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -40,10 +40,6 @@ public class SharedTrainingAccumulationFunctionTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testAccumulation1() throws Exception { |     public void testAccumulation1() throws Exception { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         INDArray updates1 = Nd4j.create(1000).assign(1.0); |         INDArray updates1 = Nd4j.create(1000).assign(1.0); | ||||||
|         INDArray updates2 = Nd4j.create(1000).assign(2.0); |         INDArray updates2 = Nd4j.create(1000).assign(2.0); | ||||||
|         INDArray expUpdates = Nd4j.create(1000).assign(3.0); |         INDArray expUpdates = Nd4j.create(1000).assign(3.0); | ||||||
|  | |||||||
| @ -43,10 +43,6 @@ public class SharedTrainingAggregateFunctionTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testAggregate1() throws Exception { |     public void testAggregate1() throws Exception { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         INDArray updates1 = Nd4j.create(1000).assign(1.0); |         INDArray updates1 = Nd4j.create(1000).assign(1.0); | ||||||
|         INDArray updates2 = Nd4j.create(1000).assign(2.0); |         INDArray updates2 = Nd4j.create(1000).assign(2.0); | ||||||
|         INDArray expUpdates = Nd4j.create(1000).assign(3.0); |         INDArray expUpdates = Nd4j.create(1000).assign(3.0); | ||||||
|  | |||||||
| @ -21,15 +21,21 @@ | |||||||
| package org.deeplearning4j.spark.parameterserver.iterators; | package org.deeplearning4j.spark.parameterserver.iterators; | ||||||
| 
 | 
 | ||||||
| import com.sun.jna.Platform; | import com.sun.jna.Platform; | ||||||
|  | import lombok.SneakyThrows; | ||||||
|  | import lombok.extern.slf4j.Slf4j; | ||||||
|  | import org.junit.jupiter.api.BeforeAll; | ||||||
| import org.junit.jupiter.api.BeforeEach; | import org.junit.jupiter.api.BeforeEach; | ||||||
| import org.junit.jupiter.api.Tag; | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.nd4j.common.resources.Downloader; | ||||||
| import org.nd4j.common.tests.tags.NativeTag; | import org.nd4j.common.tests.tags.NativeTag; | ||||||
| import org.nd4j.common.tests.tags.TagNames; | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.dataset.DataSet; | import org.nd4j.linalg.dataset.DataSet; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| 
 | 
 | ||||||
|  | import java.io.File; | ||||||
|  | import java.net.URI; | ||||||
| import java.util.ArrayList; | import java.util.ArrayList; | ||||||
| import java.util.Iterator; | import java.util.Iterator; | ||||||
| import java.util.List; | import java.util.List; | ||||||
| @ -39,17 +45,35 @@ import static org.junit.jupiter.api.Assertions.assertEquals; | |||||||
| @Tag(TagNames.SPARK) | @Tag(TagNames.SPARK) | ||||||
| @Tag(TagNames.DIST_SYSTEMS) | @Tag(TagNames.DIST_SYSTEMS) | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Slf4j | ||||||
| public class VirtualDataSetIteratorTest { | public class VirtualDataSetIteratorTest { | ||||||
|     @BeforeEach |     @BeforeEach | ||||||
|     public void setUp() throws Exception {} |     public void setUp() throws Exception {} | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  |     @BeforeAll | ||||||
|  |     @SneakyThrows | ||||||
|  |     public static void beforeAll() { | ||||||
|  |         if(Platform.isWindows()) { | ||||||
|  |             File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp"); | ||||||
|  |             File binDir = new File(hadoopHome,"bin"); | ||||||
|  |             if(!binDir.exists()) | ||||||
|  |                 binDir.mkdirs(); | ||||||
|  |             File outputFile = new File(binDir,"winutils.exe"); | ||||||
|  |             if(!outputFile.exists()) { | ||||||
|  |                 log.info("Fixing spark for windows"); | ||||||
|  |                 Downloader.download("winutils.exe", | ||||||
|  |                         URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(), | ||||||
|  |                         outputFile,"db24b404d2331a1bec7443336a5171f1",3); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testSimple1() throws Exception { |     public void testSimple1() throws Exception { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         List<Iterator<DataSet>> iterators = new ArrayList<>(); |         List<Iterator<DataSet>> iterators = new ArrayList<>(); | ||||||
| 
 | 
 | ||||||
|         List<DataSet> first = new ArrayList<>(); |         List<DataSet> first = new ArrayList<>(); | ||||||
|  | |||||||
| @ -21,12 +21,18 @@ | |||||||
| package org.deeplearning4j.spark.parameterserver.iterators; | package org.deeplearning4j.spark.parameterserver.iterators; | ||||||
| 
 | 
 | ||||||
| import com.sun.jna.Platform; | import com.sun.jna.Platform; | ||||||
|  | import lombok.SneakyThrows; | ||||||
|  | import lombok.extern.slf4j.Slf4j; | ||||||
|  | import org.junit.jupiter.api.BeforeAll; | ||||||
| import org.junit.jupiter.api.BeforeEach; | import org.junit.jupiter.api.BeforeEach; | ||||||
| import org.junit.jupiter.api.Tag; | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.nd4j.common.resources.Downloader; | ||||||
| import org.nd4j.common.tests.tags.NativeTag; | import org.nd4j.common.tests.tags.NativeTag; | ||||||
| import org.nd4j.common.tests.tags.TagNames; | import org.nd4j.common.tests.tags.TagNames; | ||||||
| 
 | 
 | ||||||
|  | import java.io.File; | ||||||
|  | import java.net.URI; | ||||||
| import java.util.ArrayList; | import java.util.ArrayList; | ||||||
| import java.util.List; | import java.util.List; | ||||||
| 
 | 
 | ||||||
| @ -35,18 +41,36 @@ import static org.junit.jupiter.api.Assertions.assertEquals; | |||||||
| @Tag(TagNames.SPARK) | @Tag(TagNames.SPARK) | ||||||
| @Tag(TagNames.DIST_SYSTEMS) | @Tag(TagNames.DIST_SYSTEMS) | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Slf4j | ||||||
| public class VirtualIteratorTest { | public class VirtualIteratorTest { | ||||||
|     @BeforeEach |     @BeforeEach | ||||||
|     public void setUp() throws Exception { |     public void setUp() throws Exception { | ||||||
|         // |         // | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  |     @BeforeAll | ||||||
|  |     @SneakyThrows | ||||||
|  |     public static void beforeAll() { | ||||||
|  |         if(Platform.isWindows()) { | ||||||
|  |             File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp"); | ||||||
|  |             File binDir = new File(hadoopHome,"bin"); | ||||||
|  |             if(!binDir.exists()) | ||||||
|  |                 binDir.mkdirs(); | ||||||
|  |             File outputFile = new File(binDir,"winutils.exe"); | ||||||
|  |             if(!outputFile.exists()) { | ||||||
|  |                 log.info("Fixing spark for windows"); | ||||||
|  |                 Downloader.download("winutils.exe", | ||||||
|  |                         URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(), | ||||||
|  |                         outputFile,"db24b404d2331a1bec7443336a5171f1",3); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testIteration1() throws Exception { |     public void testIteration1() throws Exception { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         List<Integer> integers = new ArrayList<>(); |         List<Integer> integers = new ArrayList<>(); | ||||||
|         for (int i = 0; i < 100; i++) { |         for (int i = 0; i < 100; i++) { | ||||||
|             integers.add(i); |             integers.add(i); | ||||||
|  | |||||||
| @ -21,19 +21,24 @@ | |||||||
| package org.deeplearning4j.spark.parameterserver.modelimport.elephas; | package org.deeplearning4j.spark.parameterserver.modelimport.elephas; | ||||||
| 
 | 
 | ||||||
| import com.sun.jna.Platform; | import com.sun.jna.Platform; | ||||||
|  | import lombok.SneakyThrows; | ||||||
|  | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.apache.spark.api.java.JavaSparkContext; | import org.apache.spark.api.java.JavaSparkContext; | ||||||
| import org.deeplearning4j.spark.impl.graph.SparkComputationGraph; | import org.deeplearning4j.spark.impl.graph.SparkComputationGraph; | ||||||
| import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; | import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; | ||||||
| import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; | import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; | ||||||
| import org.deeplearning4j.spark.parameterserver.BaseSparkTest; | import org.deeplearning4j.spark.parameterserver.BaseSparkTest; | ||||||
| import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster; | import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster; | ||||||
|  | import org.junit.jupiter.api.BeforeAll; | ||||||
| import org.junit.jupiter.api.Tag; | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
| import org.nd4j.common.io.ClassPathResource; | import org.nd4j.common.io.ClassPathResource; | ||||||
|  | import org.nd4j.common.resources.Downloader; | ||||||
| import org.nd4j.common.tests.tags.NativeTag; | import org.nd4j.common.tests.tags.NativeTag; | ||||||
| import org.nd4j.common.tests.tags.TagNames; | import org.nd4j.common.tests.tags.TagNames; | ||||||
| 
 | 
 | ||||||
| import java.io.File; | import java.io.File; | ||||||
|  | import java.net.URI; | ||||||
| import java.nio.file.Files; | import java.nio.file.Files; | ||||||
| import java.nio.file.StandardCopyOption; | import java.nio.file.StandardCopyOption; | ||||||
| 
 | 
 | ||||||
| @ -43,14 +48,32 @@ import static org.junit.jupiter.api.Assertions.assertTrue; | |||||||
| @Tag(TagNames.SPARK) | @Tag(TagNames.SPARK) | ||||||
| @Tag(TagNames.DIST_SYSTEMS) | @Tag(TagNames.DIST_SYSTEMS) | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Slf4j | ||||||
| public class TestElephasImport extends BaseSparkTest { | public class TestElephasImport extends BaseSparkTest { | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  |     @BeforeAll | ||||||
|  |     @SneakyThrows | ||||||
|  |     public static void beforeAll() { | ||||||
|  |         if(Platform.isWindows()) { | ||||||
|  |             File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp"); | ||||||
|  |             File binDir = new File(hadoopHome,"bin"); | ||||||
|  |             if(!binDir.exists()) | ||||||
|  |                 binDir.mkdirs(); | ||||||
|  |             File outputFile = new File(binDir,"winutils.exe"); | ||||||
|  |             if(!outputFile.exists()) { | ||||||
|  |                 log.info("Fixing spark for windows"); | ||||||
|  |                 Downloader.download("winutils.exe", | ||||||
|  |                         URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(), | ||||||
|  |                         outputFile,"db24b404d2331a1bec7443336a5171f1",3); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testElephasSequentialImport() throws Exception { |     public void testElephasSequentialImport() throws Exception { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         String modelPath = "modelimport/elephas/elephas_sequential.h5"; |         String modelPath = "modelimport/elephas/elephas_sequential.h5"; | ||||||
|         SparkDl4jMultiLayer model = importElephasSequential(sc, modelPath); |         SparkDl4jMultiLayer model = importElephasSequential(sc, modelPath); | ||||||
|         // System.out.println(model.getNetwork().summary()); |         // System.out.println(model.getNetwork().summary()); | ||||||
|  | |||||||
| @ -44,7 +44,6 @@ import org.deeplearning4j.spark.api.RDDTrainingApproach; | |||||||
| import org.deeplearning4j.spark.api.TrainingMaster; | import org.deeplearning4j.spark.api.TrainingMaster; | ||||||
| import org.deeplearning4j.spark.impl.graph.SparkComputationGraph; | import org.deeplearning4j.spark.impl.graph.SparkComputationGraph; | ||||||
| import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; | import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; | ||||||
| import org.deeplearning4j.spark.parameterserver.BaseSparkTest; |  | ||||||
| import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster; | import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster; | ||||||
| import org.junit.jupiter.api.Disabled; | import org.junit.jupiter.api.Disabled; | ||||||
| 
 | 
 | ||||||
| @ -75,6 +74,7 @@ import java.util.*; | |||||||
| import java.util.concurrent.ConcurrentHashMap; | import java.util.concurrent.ConcurrentHashMap; | ||||||
| 
 | 
 | ||||||
| import static org.junit.jupiter.api.Assertions.*; | import static org.junit.jupiter.api.Assertions.*; | ||||||
|  | import org.deeplearning4j.spark.parameterserver.BaseSparkTest; | ||||||
| 
 | 
 | ||||||
| @Slf4j | @Slf4j | ||||||
| //@Disabled("AB 2019/05/21 - Failing - Issue #7657") | //@Disabled("AB 2019/05/21 - Failing - Issue #7657") | ||||||
| @ -82,6 +82,8 @@ import static org.junit.jupiter.api.Assertions.*; | |||||||
| @Tag(TagNames.SPARK) | @Tag(TagNames.SPARK) | ||||||
| @Tag(TagNames.DIST_SYSTEMS) | @Tag(TagNames.DIST_SYSTEMS) | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Tag(TagNames.LONG_TEST) | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
| public class GradientSharingTrainingTest extends BaseSparkTest { | public class GradientSharingTrainingTest extends BaseSparkTest { | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -339,11 +341,12 @@ public class GradientSharingTrainingTest extends BaseSparkTest { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     @Test @Disabled |     @Test | ||||||
|     public void testEpochUpdating(@TempDir Path testDir) throws Exception { |     public void testEpochUpdating(@TempDir Path testDir) throws Exception { | ||||||
|         //Ensure that epoch counter is incremented properly on the workers |         //Ensure that epoch counter is incremented properly on the workers | ||||||
| 
 | 
 | ||||||
|         File temp = testDir.toFile(); |         File temp = testDir.resolve("new-dir-" + UUID.randomUUID().toString()).toFile(); | ||||||
|  |         temp.mkdirs(); | ||||||
| 
 | 
 | ||||||
|         //TODO this probably won't work everywhere... |         //TODO this probably won't work everywhere... | ||||||
|         String controller = Inet4Address.getLocalHost().getHostAddress(); |         String controller = Inet4Address.getLocalHost().getHostAddress(); | ||||||
| @ -394,7 +397,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest { | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         JavaRDD<String> pathRdd = sc.parallelize(paths); |         JavaRDD<String> pathRdd = sc.parallelize(paths); | ||||||
|         for( int i=0; i<3; i++ ) { |         for( int i = 0; i < 3; i++) { | ||||||
|             ThresholdAlgorithm ta = tm.getThresholdAlgorithm(); |             ThresholdAlgorithm ta = tm.getThresholdAlgorithm(); | ||||||
|             sparkNet.fitPaths(pathRdd); |             sparkNet.fitPaths(pathRdd); | ||||||
|             //Check also that threshold algorithm was updated/averaged |             //Check also that threshold algorithm was updated/averaged | ||||||
|  | |||||||
| @ -1,123 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.spark.iterator; |  | ||||||
| 
 |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.apache.spark.TaskContext; |  | ||||||
| import org.apache.spark.TaskContextHelper; |  | ||||||
| import org.nd4j.linalg.dataset.AsyncDataSetIterator; |  | ||||||
| import org.nd4j.linalg.api.memory.MemoryWorkspace; |  | ||||||
| import org.nd4j.linalg.dataset.DataSet; |  | ||||||
| import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; |  | ||||||
| import org.nd4j.linalg.dataset.callbacks.DataSetCallback; |  | ||||||
| import org.nd4j.linalg.dataset.callbacks.DefaultCallback; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import java.util.concurrent.BlockingQueue; |  | ||||||
| import java.util.concurrent.LinkedBlockingQueue; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| public class SparkADSI extends AsyncDataSetIterator { |  | ||||||
|     protected TaskContext context; |  | ||||||
| 
 |  | ||||||
|     protected SparkADSI() { |  | ||||||
|         super(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkADSI(DataSetIterator baseIterator) { |  | ||||||
|         this(baseIterator, 8); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue) { |  | ||||||
|         this(iterator, queueSize, queue, true); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkADSI(DataSetIterator baseIterator, int queueSize) { |  | ||||||
|         this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkADSI(DataSetIterator baseIterator, int queueSize, boolean useWorkspace) { |  | ||||||
|         this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkADSI(DataSetIterator baseIterator, int queueSize, boolean useWorkspace, Integer deviceId) { |  | ||||||
|         this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace, new DefaultCallback(), |  | ||||||
|                         deviceId); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkADSI(DataSetIterator baseIterator, int queueSize, boolean useWorkspace, DataSetCallback callback) { |  | ||||||
|         this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace, callback); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace) { |  | ||||||
|         this(iterator, queueSize, queue, useWorkspace, new DefaultCallback()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace, |  | ||||||
|                     DataSetCallback callback) { |  | ||||||
|         this(iterator, queueSize, queue, useWorkspace, callback, Nd4j.getAffinityManager().getDeviceForCurrentThread()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace, |  | ||||||
|                     DataSetCallback callback, Integer deviceId) { |  | ||||||
|         this(); |  | ||||||
| 
 |  | ||||||
|         if (queueSize < 2) |  | ||||||
|             queueSize = 2; |  | ||||||
| 
 |  | ||||||
|         this.deviceId = deviceId; |  | ||||||
|         this.callback = callback; |  | ||||||
|         this.useWorkspace = useWorkspace; |  | ||||||
|         this.buffer = queue; |  | ||||||
|         this.prefetchSize = queueSize; |  | ||||||
|         this.backedIterator = iterator; |  | ||||||
|         this.workspaceId = "SADSI_ITER-" + java.util.UUID.randomUUID().toString(); |  | ||||||
| 
 |  | ||||||
|         if (iterator.resetSupported()) |  | ||||||
|             this.backedIterator.reset(); |  | ||||||
| 
 |  | ||||||
|         context = TaskContext.get(); |  | ||||||
| 
 |  | ||||||
|         this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null, Nd4j.getAffinityManager().getDeviceForCurrentThread()); |  | ||||||
| 
 |  | ||||||
|         /** |  | ||||||
|          * We want to ensure, that background thread will have the same thread->device affinity, as master thread |  | ||||||
|          */ |  | ||||||
| 
 |  | ||||||
|         thread.setDaemon(true); |  | ||||||
|         thread.start(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected void externalCall() { |  | ||||||
|         TaskContextHelper.setTaskContext(context); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public class SparkPrefetchThread extends AsyncPrefetchThread { |  | ||||||
| 
 |  | ||||||
|         protected SparkPrefetchThread(BlockingQueue<DataSet> queue, DataSetIterator iterator, DataSet terminator, MemoryWorkspace workspace, int deviceId) { |  | ||||||
|             super(queue, iterator, terminator, workspace, deviceId); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,118 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.spark.iterator; |  | ||||||
| 
 |  | ||||||
| import lombok.NonNull; |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.apache.spark.TaskContext; |  | ||||||
| import org.apache.spark.TaskContextHelper; |  | ||||||
| import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator; |  | ||||||
| import org.nd4j.linalg.dataset.api.MultiDataSet; |  | ||||||
| import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; |  | ||||||
| import org.nd4j.linalg.dataset.callbacks.DataSetCallback; |  | ||||||
| import org.nd4j.linalg.dataset.callbacks.DefaultCallback; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import java.util.concurrent.BlockingQueue; |  | ||||||
| import java.util.concurrent.LinkedBlockingQueue; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| public class SparkAMDSI extends AsyncMultiDataSetIterator { |  | ||||||
|     protected TaskContext context; |  | ||||||
| 
 |  | ||||||
|     protected SparkAMDSI() { |  | ||||||
|         super(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkAMDSI(MultiDataSetIterator baseIterator) { |  | ||||||
|         this(baseIterator, 8); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue) { |  | ||||||
|         this(iterator, queueSize, queue, true); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize) { |  | ||||||
|         this(baseIterator, queueSize, new LinkedBlockingQueue<MultiDataSet>(queueSize)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace) { |  | ||||||
|         this(baseIterator, queueSize, new LinkedBlockingQueue<MultiDataSet>(queueSize), useWorkspace); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace, Integer deviceId) { |  | ||||||
|         this(baseIterator, queueSize, new LinkedBlockingQueue<MultiDataSet>(queueSize), useWorkspace, |  | ||||||
|                         new DefaultCallback(), deviceId); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace, |  | ||||||
|                     DataSetCallback callback) { |  | ||||||
|         this(baseIterator, queueSize, new LinkedBlockingQueue<MultiDataSet>(queueSize), useWorkspace, callback); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue, |  | ||||||
|                     boolean useWorkspace) { |  | ||||||
|         this(iterator, queueSize, queue, useWorkspace, null); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue, |  | ||||||
|                     boolean useWorkspace, DataSetCallback callback) { |  | ||||||
|         this(iterator, queueSize, queue, useWorkspace, callback, Nd4j.getAffinityManager().getDeviceForCurrentThread()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue, |  | ||||||
|                     boolean useWorkspace, DataSetCallback callback, Integer deviceId) { |  | ||||||
|         this(); |  | ||||||
| 
 |  | ||||||
|         if (queueSize < 2) |  | ||||||
|             queueSize = 2; |  | ||||||
| 
 |  | ||||||
|         this.callback = callback; |  | ||||||
|         this.buffer = queue; |  | ||||||
|         this.backedIterator = iterator; |  | ||||||
|         this.useWorkspaces = useWorkspace; |  | ||||||
|         this.prefetchSize = queueSize; |  | ||||||
|         this.workspaceId = "SAMDSI_ITER-" + java.util.UUID.randomUUID().toString(); |  | ||||||
|         this.deviceId = deviceId; |  | ||||||
| 
 |  | ||||||
|         if (iterator.resetSupported()) |  | ||||||
|             this.backedIterator.reset(); |  | ||||||
| 
 |  | ||||||
|         this.thread = new SparkPrefetchThread(buffer, iterator, terminator, Nd4j.getAffinityManager().getDeviceForCurrentThread()); |  | ||||||
| 
 |  | ||||||
|         context = TaskContext.get(); |  | ||||||
| 
 |  | ||||||
|         thread.setDaemon(true); |  | ||||||
|         thread.start(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected void externalCall() { |  | ||||||
|         TaskContextHelper.setTaskContext(context); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     protected class SparkPrefetchThread extends AsyncPrefetchThread { |  | ||||||
| 
 |  | ||||||
|         protected SparkPrefetchThread(@NonNull BlockingQueue<MultiDataSet> queue, @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator, int deviceId) { |  | ||||||
|             super(queue, iterator, terminator, deviceId); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -64,10 +64,6 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testEarlyStoppingIris() { |     public void testEarlyStoppingIris() { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() |         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | ||||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) |                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||||
|                         .updater(new Sgd()).weightInit(WeightInit.XAVIER).list() |                         .updater(new Sgd()).weightInit(WeightInit.XAVIER).list() | ||||||
|  | |||||||
| @ -67,10 +67,6 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testEarlyStoppingIris() { |     public void testEarlyStoppingIris() { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() |         ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() | ||||||
|                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) |                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | ||||||
|                         .updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") |                         .updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") | ||||||
|  | |||||||
| @ -76,10 +76,6 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testDataVecDataSetFunction(@TempDir Path testDir) throws Exception { |     public void testDataVecDataSetFunction(@TempDir Path testDir) throws Exception { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         JavaSparkContext sc = getContext(); |         JavaSparkContext sc = getContext(); | ||||||
| 
 | 
 | ||||||
|         File f = testDir.toFile(); |         File f = testDir.toFile(); | ||||||
|  | |||||||
| @ -51,10 +51,6 @@ public class TestExport extends BaseSparkTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testBatchAndExportDataSetsFunction() throws Exception { |     public void testBatchAndExportDataSetsFunction() throws Exception { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         String baseDir = System.getProperty("java.io.tmpdir"); |         String baseDir = System.getProperty("java.io.tmpdir"); | ||||||
|         baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExport/"); |         baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExport/"); | ||||||
|         baseDir = baseDir.replaceAll("\\\\", "/"); |         baseDir = baseDir.replaceAll("\\\\", "/"); | ||||||
|  | |||||||
| @ -70,10 +70,6 @@ public class TestPreProcessedData extends BaseSparkTest { | |||||||
|     @Test |     @Test | ||||||
|     public void testPreprocessedData() { |     public void testPreprocessedData() { | ||||||
|         //Test _loading_ of preprocessed data |         //Test _loading_ of preprocessed data | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         int dataSetObjSize = 5; |         int dataSetObjSize = 5; | ||||||
|         int batchSizePerExecutor = 10; |         int batchSizePerExecutor = 10; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -52,10 +52,6 @@ public class TestCustomLayer extends BaseSparkTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testSparkWithCustomLayer() { |     public void testSparkWithCustomLayer() { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         //Basic test - checks whether exceptions etc are thrown with custom layers + spark |         //Basic test - checks whether exceptions etc are thrown with custom layers + spark | ||||||
|         //Custom layers are tested more extensively in dl4j core |         //Custom layers are tested more extensively in dl4j core | ||||||
|         MultiLayerConfiguration conf = |         MultiLayerConfiguration conf = | ||||||
|  | |||||||
| @ -77,10 +77,6 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testEvaluationSimple() throws Exception { |     public void testEvaluationSimple() throws Exception { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         Nd4j.getRandom().setSeed(12345); |         Nd4j.getRandom().setSeed(12345); | ||||||
| 
 | 
 | ||||||
|         for( int evalWorkers : new int[]{1, 4, 8}) { |         for( int evalWorkers : new int[]{1, 4, 8}) { | ||||||
|  | |||||||
| @ -61,10 +61,6 @@ public class TestTrainingStatsCollection extends BaseSparkTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testStatsCollection() throws Exception { |     public void testStatsCollection() throws Exception { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         int nWorkers = numExecutors(); |         int nWorkers = numExecutors(); | ||||||
| 
 | 
 | ||||||
|         JavaSparkContext sc = getContext(); |         JavaSparkContext sc = getContext(); | ||||||
|  | |||||||
| @ -60,10 +60,6 @@ public class TestListeners extends BaseSparkTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testStatsCollection() { |     public void testStatsCollection() { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         JavaSparkContext sc = getContext(); |         JavaSparkContext sc = getContext(); | ||||||
|         int nExecutors = numExecutors(); |         int nExecutors = numExecutors(); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -54,10 +54,6 @@ public class TestRepartitioning extends BaseSparkTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testRepartitioning() { |     public void testRepartitioning() { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         List<String> list = new ArrayList<>(); |         List<String> list = new ArrayList<>(); | ||||||
|         for (int i = 0; i < 1000; i++) { |         for (int i = 0; i < 1000; i++) { | ||||||
|             list.add(String.valueOf(i)); |             list.add(String.valueOf(i)); | ||||||
|  | |||||||
| @ -52,10 +52,6 @@ public class TestValidation extends BaseSparkTest { | |||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
|     public void testDataSetValidation(@TempDir Path folder) throws Exception { |     public void testDataSetValidation(@TempDir Path folder) throws Exception { | ||||||
|         if(Platform.isWindows()) { |  | ||||||
|             //Spark tests don't run on windows |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         File f = folder.toFile(); |         File f = folder.toFile(); | ||||||
| 
 | 
 | ||||||
|         for( int i = 0; i < 3; i++ ) { |         for( int i = 0; i < 3; i++ ) { | ||||||
|  | |||||||
| @ -38,11 +38,17 @@ | |||||||
|         <module>deeplearning4j-ui</module> |         <module>deeplearning4j-ui</module> | ||||||
|         <module>deeplearning4j-ui-components</module> |         <module>deeplearning4j-ui-components</module> | ||||||
|         <module>deeplearning4j-ui-model</module> |         <module>deeplearning4j-ui-model</module> | ||||||
|         <module>deeplearning4j-ui-standalone</module> |  | ||||||
|         <module>deeplearning4j-vertx</module> |         <module>deeplearning4j-vertx</module> | ||||||
|     </modules> |     </modules> | ||||||
| 
 | 
 | ||||||
|     <profiles> |     <profiles> | ||||||
|  |         <profile> | ||||||
|  |             <id>ui-jar</id> | ||||||
|  |             <modules> | ||||||
|  |                 <module>deeplearning4j-ui-standalone</module> | ||||||
|  |             </modules> | ||||||
|  |         </profile> | ||||||
|  | 
 | ||||||
|         <profile> |         <profile> | ||||||
|             <id>nd4j-tests-cpu</id> |             <id>nd4j-tests-cpu</id> | ||||||
|         </profile> |         </profile> | ||||||
|  | |||||||
| @ -41,6 +41,7 @@ import java.io.File; | |||||||
| @Tag(TagNames.DL4J_OLD_API) | @Tag(TagNames.DL4J_OLD_API) | ||||||
| @NativeTag | @NativeTag | ||||||
| @Tag(TagNames.LONG_TEST) | @Tag(TagNames.LONG_TEST) | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
| public class MiscTests extends BaseDL4JTest { | public class MiscTests extends BaseDL4JTest { | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|  | |||||||
| @ -51,6 +51,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; | |||||||
| @Tag(TagNames.DL4J_OLD_API) | @Tag(TagNames.DL4J_OLD_API) | ||||||
| @NativeTag | @NativeTag | ||||||
| @Tag(TagNames.LONG_TEST) | @Tag(TagNames.LONG_TEST) | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
| public class TestDownload extends BaseDL4JTest { | public class TestDownload extends BaseDL4JTest { | ||||||
|     @TempDir |     @TempDir | ||||||
|     static Path sharedTempDir; |     static Path sharedTempDir; | ||||||
|  | |||||||
| @ -61,6 +61,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; | |||||||
| @Tag(TagNames.DL4J_OLD_API) | @Tag(TagNames.DL4J_OLD_API) | ||||||
| @NativeTag | @NativeTag | ||||||
| @Tag(TagNames.LONG_TEST) | @Tag(TagNames.LONG_TEST) | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
| public class TestImageNet extends BaseDL4JTest { | public class TestImageNet extends BaseDL4JTest { | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|  | |||||||
| @ -2308,7 +2308,7 @@ public class Nd4j { | |||||||
|             data2.add(readSplit(data)); |             data2.add(readSplit(data)); | ||||||
|         } |         } | ||||||
|         float[][] fArr = new float[data2.size()][0]; |         float[][] fArr = new float[data2.size()][0]; | ||||||
|         for(int i=0; i<data2.size(); i++ ){ |         for(int i = 0; i < data2.size(); i++) { | ||||||
|             fArr[i] = data2.get(i); |             fArr[i] = data2.get(i); | ||||||
|         } |         } | ||||||
|         ret = Nd4j.createFromArray(fArr).castTo(dataType); |         ret = Nd4j.createFromArray(fArr).castTo(dataType); | ||||||
| @ -2785,7 +2785,7 @@ public class Nd4j { | |||||||
|      * @return the random ndarray with the specified shape |      * @return the random ndarray with the specified shape | ||||||
|      */ |      */ | ||||||
|     public static INDArray rand(@NonNull int... shape) { |     public static INDArray rand(@NonNull int... shape) { | ||||||
|         INDArray ret = createUninitialized(shape, order()).castTo(Nd4j.defaultFloatingPointType()); //INSTANCE.rand(shape, Nd4j.getRandom()); |         INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom()); | ||||||
|         return rand(ret); |         return rand(ret); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -2793,7 +2793,7 @@ public class Nd4j { | |||||||
|      * See {@link #rand(int[])} |      * See {@link #rand(int[])} | ||||||
|      */ |      */ | ||||||
|     public static INDArray rand(@NonNull long... shape) { |     public static INDArray rand(@NonNull long... shape) { | ||||||
|         INDArray ret = createUninitialized(shape, order()).castTo(Nd4j.defaultFloatingPointType()); //INSTANCE.rand(shape, Nd4j.getRandom()); |         INDArray ret = createUninitialized(shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom()); | ||||||
|         return rand(ret); |         return rand(ret); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -2806,7 +2806,7 @@ public class Nd4j { | |||||||
|     public static INDArray rand(@NonNull DataType dataType, @NonNull long... shape) { |     public static INDArray rand(@NonNull DataType dataType, @NonNull long... shape) { | ||||||
|         Preconditions.checkArgument(dataType.isFPType(), |         Preconditions.checkArgument(dataType.isFPType(), | ||||||
|                 "Can't create a random array of a non-floating point data type"); |                 "Can't create a random array of a non-floating point data type"); | ||||||
|         INDArray ret = createUninitialized(dataType, shape, order()).castTo(Nd4j.defaultFloatingPointType()); //INSTANCE.rand(shape, Nd4j.getRandom()); |         INDArray ret = createUninitialized(dataType, shape, order()); //INSTANCE.rand(shape, Nd4j.getRandom()); | ||||||
|         return rand(ret); |         return rand(ret); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -2820,7 +2820,7 @@ public class Nd4j { | |||||||
|      * @return the random ndarray with the specified shape |      * @return the random ndarray with the specified shape | ||||||
|      */ |      */ | ||||||
|     public static INDArray rand(char order, @NonNull int... shape) { |     public static INDArray rand(char order, @NonNull int... shape) { | ||||||
|         INDArray ret = Nd4j.createUninitialized(shape, order).castTo(Nd4j.defaultFloatingPointType()); //INSTANCE.rand(order, shape); |         INDArray ret = Nd4j.createUninitialized(shape, order); //INSTANCE.rand(order, shape); | ||||||
|         return rand(ret); |         return rand(ret); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -2829,7 +2829,7 @@ public class Nd4j { | |||||||
|      */ |      */ | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public static INDArray rand(@NonNull DataType dataType, int[] shape, char order) { |     public static INDArray rand(@NonNull DataType dataType, int[] shape, char order) { | ||||||
|         return rand(dataType, order, ArrayUtil.toLongArray(shape)).castTo(Nd4j.defaultFloatingPointType()); |         return rand(dataType, order, ArrayUtil.toLongArray(shape)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
| @ -2837,7 +2837,7 @@ public class Nd4j { | |||||||
|      */ |      */ | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public static INDArray rand(@NonNull DataType dataType, char order, @NonNull int... shape) { |     public static INDArray rand(@NonNull DataType dataType, char order, @NonNull int... shape) { | ||||||
|         return rand(dataType, order, ArrayUtil.toLongArray(shape)).castTo(Nd4j.defaultFloatingPointType()); |         return rand(dataType, order, ArrayUtil.toLongArray(shape)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
| @ -2851,7 +2851,7 @@ public class Nd4j { | |||||||
|      * @return the random ndarray with the specified shape |      * @return the random ndarray with the specified shape | ||||||
|      */ |      */ | ||||||
|     public static INDArray rand(@NonNull DataType dataType, char order, @NonNull long... shape) { |     public static INDArray rand(@NonNull DataType dataType, char order, @NonNull long... shape) { | ||||||
|         INDArray ret = Nd4j.createUninitialized(dataType, shape, order).castTo(Nd4j.defaultFloatingPointType()); |         INDArray ret = Nd4j.createUninitialized(dataType, shape, order); | ||||||
|         return rand(ret); |         return rand(ret); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -2866,7 +2866,7 @@ public class Nd4j { | |||||||
|      * @return the random ndarray with the specified shape |      * @return the random ndarray with the specified shape | ||||||
|      */ |      */ | ||||||
|     public static INDArray rand(@NonNull DataType dataType, @NonNull int... shape) { |     public static INDArray rand(@NonNull DataType dataType, @NonNull int... shape) { | ||||||
|         INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order()).castTo(Nd4j.defaultFloatingPointType()); |         INDArray ret = Nd4j.createUninitialized(dataType, ArrayUtil.toLongArray(shape), Nd4j.order()); | ||||||
|         return rand(ret); |         return rand(ret); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -2911,7 +2911,7 @@ public class Nd4j { | |||||||
|      * @return the random ndarray with the specified shape |      * @return the random ndarray with the specified shape | ||||||
|      */ |      */ | ||||||
|     public static INDArray rand(long seed, @NonNull long... shape) { |     public static INDArray rand(long seed, @NonNull long... shape) { | ||||||
|         INDArray ret = createUninitialized(shape, Nd4j.order()).castTo(Nd4j.defaultFloatingPointType());//;INSTANCE.rand(shape, seed); |         INDArray ret = createUninitialized(shape, Nd4j.order());//;INSTANCE.rand(shape, seed); | ||||||
|         return rand(ret, seed); |         return rand(ret, seed); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -2920,7 +2920,7 @@ public class Nd4j { | |||||||
|      */ |      */ | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public static INDArray rand(int[] shape, long seed) { |     public static INDArray rand(int[] shape, long seed) { | ||||||
|         return rand(seed, ArrayUtil.toLongArray(shape)).castTo(Nd4j.defaultFloatingPointType()); |         return rand(seed, ArrayUtil.toLongArray(shape)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -2943,7 +2943,7 @@ public class Nd4j { | |||||||
|      */ |      */ | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public static INDArray rand(int[] shape, @NonNull org.nd4j.linalg.api.rng.Random rng) { |     public static INDArray rand(int[] shape, @NonNull org.nd4j.linalg.api.rng.Random rng) { | ||||||
|         return rand(rng, ArrayUtil.toLongArray(shape)).castTo(Nd4j.defaultFloatingPointType()); |         return rand(rng, ArrayUtil.toLongArray(shape)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
| @ -2954,7 +2954,7 @@ public class Nd4j { | |||||||
|      * @return the random ndarray with the specified shape |      * @return the random ndarray with the specified shape | ||||||
|      */ |      */ | ||||||
|     public static INDArray rand(@NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) { |     public static INDArray rand(@NonNull org.nd4j.linalg.api.rng.Random rng, @NonNull long... shape) { | ||||||
|         INDArray ret = createUninitialized(shape, Nd4j.order()).castTo(Nd4j.defaultFloatingPointType()); //INSTANCE.rand(shape, rng); |         INDArray ret = createUninitialized(shape, Nd4j.order()); //INSTANCE.rand(shape, rng); | ||||||
|         return rand(ret, rng); |         return rand(ret, rng); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -2963,7 +2963,7 @@ public class Nd4j { | |||||||
|      */ |      */ | ||||||
|     @Deprecated |     @Deprecated | ||||||
|     public static INDArray rand(int[] shape, @NonNull Distribution dist) { |     public static INDArray rand(int[] shape, @NonNull Distribution dist) { | ||||||
|         return rand(dist, ArrayUtil.toLongArray(shape)).castTo(Nd4j.defaultFloatingPointType()); |         return rand(dist, ArrayUtil.toLongArray(shape)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|  | |||||||
| @ -75,7 +75,8 @@ public class BinarySerde { | |||||||
|         ByteBuffer byteBuffer = buffer.hasArray() ? ByteBuffer.allocateDirect(buffer.array().length).put(buffer.array()) |         ByteBuffer byteBuffer = buffer.hasArray() ? ByteBuffer.allocateDirect(buffer.array().length).put(buffer.array()) | ||||||
|                 .order(ByteOrder.nativeOrder()) : buffer.order(ByteOrder.nativeOrder()); |                 .order(ByteOrder.nativeOrder()) : buffer.order(ByteOrder.nativeOrder()); | ||||||
|         //bump the byte buffer to the proper position |         //bump the byte buffer to the proper position | ||||||
|         byteBuffer.position(offset); |         Buffer buffer1 = (Buffer) byteBuffer; | ||||||
|  |         buffer1.position(offset); | ||||||
|         int rank = byteBuffer.getInt(); |         int rank = byteBuffer.getInt(); | ||||||
|         if (rank < 0) |         if (rank < 0) | ||||||
|             throw new IllegalStateException("Found negative integer. Corrupt serialization?"); |             throw new IllegalStateException("Found negative integer. Corrupt serialization?"); | ||||||
| @ -99,7 +100,8 @@ public class BinarySerde { | |||||||
|             DataBuffer buff = Nd4j.createBuffer(slice, type, (int) Shape.length(shapeBuff)); |             DataBuffer buff = Nd4j.createBuffer(slice, type, (int) Shape.length(shapeBuff)); | ||||||
|             //advance past the data |             //advance past the data | ||||||
|             int position = byteBuffer.position() + (buff.getElementSize() * (int) buff.length()); |             int position = byteBuffer.position() + (buff.getElementSize() * (int) buff.length()); | ||||||
|             byteBuffer.position(position); |             Buffer buffer2 = (Buffer) byteBuffer; | ||||||
|  |             buffer2.position(position); | ||||||
|             //create the final array |             //create the final array | ||||||
|             //TODO: see how to avoid dup here |             //TODO: see how to avoid dup here | ||||||
|             INDArray arr = Nd4j.createArrayFromShapeBuffer(buff.dup(), shapeBuff.dup()); |             INDArray arr = Nd4j.createArrayFromShapeBuffer(buff.dup(), shapeBuff.dup()); | ||||||
| @ -116,7 +118,8 @@ public class BinarySerde { | |||||||
|             INDArray arr = Nd4j.createArrayFromShapeBuffer(compressedDataBuffer.dup(), shapeBuff.dup()); |             INDArray arr = Nd4j.createArrayFromShapeBuffer(compressedDataBuffer.dup(), shapeBuff.dup()); | ||||||
|             //advance past the data |             //advance past the data | ||||||
|             int compressLength = (int) compressionDescriptor.getCompressedLength(); |             int compressLength = (int) compressionDescriptor.getCompressedLength(); | ||||||
|             byteBuffer.position(byteBuffer.position() + compressLength); |             Buffer buffer2 = (Buffer) byteBuffer; | ||||||
|  |             buffer2.position(buffer2.position() + compressLength); | ||||||
|             return Pair.of(arr, byteBuffer); |             return Pair.of(arr, byteBuffer); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -140,6 +140,7 @@ | |||||||
|                 <groupId>org.apache.maven.plugins</groupId> |                 <groupId>org.apache.maven.plugins</groupId> | ||||||
|                 <artifactId>maven-surefire-plugin</artifactId> |                 <artifactId>maven-surefire-plugin</artifactId> | ||||||
|                 <configuration> |                 <configuration> | ||||||
|  |                     <forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/> | ||||||
|                     <forkCount>${cpu.core.count}</forkCount> |                     <forkCount>${cpu.core.count}</forkCount> | ||||||
|                     <reuseForks>false</reuseForks> |                     <reuseForks>false</reuseForks> | ||||||
|                     <environmentVariables> |                     <environmentVariables> | ||||||
| @ -162,7 +163,7 @@ | |||||||
|                         Maximum heap size was set to 6g, as a minimum required value for tests run. |                         Maximum heap size was set to 6g, as a minimum required value for tests run. | ||||||
|                         Depending on a build machine, default value is not always enough. |                         Depending on a build machine, default value is not always enough. | ||||||
|                     --> |                     --> | ||||||
|                     <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g</argLine> |                     <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine> | ||||||
|                 </configuration> |                 </configuration> | ||||||
|             </plugin> |             </plugin> | ||||||
|             <plugin> |             <plugin> | ||||||
|  | |||||||
| @ -526,7 +526,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { | |||||||
|     public INDArray toFlattened(char order, Collection<INDArray> matrices) { |     public INDArray toFlattened(char order, Collection<INDArray> matrices) { | ||||||
|         Preconditions.checkArgument(matrices.size() > 0, "toFlattened expects > 0 operands"); |         Preconditions.checkArgument(matrices.size() > 0, "toFlattened expects > 0 operands"); | ||||||
| 
 | 
 | ||||||
|         return Nd4j.exec(new Flatten(order, matrices.toArray(new INDArray[matrices.size()])))[0]; |         return Nd4j.exec(new Flatten(order, matrices.toArray(new INDArray[matrices.size()])))[0] | ||||||
|  |                 .castTo(matrices.iterator().next().dataType()); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|  | |||||||
| @ -124,6 +124,7 @@ | |||||||
|                     <groupId>org.apache.maven.plugins</groupId> |                     <groupId>org.apache.maven.plugins</groupId> | ||||||
|                     <artifactId>maven-surefire-plugin</artifactId> |                     <artifactId>maven-surefire-plugin</artifactId> | ||||||
|                     <configuration> |                     <configuration> | ||||||
|  |                         <forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/> | ||||||
|                         <forkCount>${cpu.core.count}</forkCount> |                         <forkCount>${cpu.core.count}</forkCount> | ||||||
|                         <reuseForks>false</reuseForks> |                         <reuseForks>false</reuseForks> | ||||||
|                         <environmentVariables> |                         <environmentVariables> | ||||||
| @ -139,7 +140,12 @@ | |||||||
|                             Maximum heap size was set to 8g, as a minimum required value for tests run. |                             Maximum heap size was set to 8g, as a minimum required value for tests run. | ||||||
|                             Depending on a build machine, default value is not always enough. |                             Depending on a build machine, default value is not always enough. | ||||||
|                         --> |                         --> | ||||||
|                         <argLine>-Xmx2g</argLine> |                         <argLine>-Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine> | ||||||
|  |                         <forkedProcessTimeoutInSeconds>240</forkedProcessTimeoutInSeconds> | ||||||
|  |                         <forkedProcessExitTimeoutInSeconds>240</forkedProcessExitTimeoutInSeconds> | ||||||
|  |                         <parallelTestsTimeoutInSeconds>240</parallelTestsTimeoutInSeconds> | ||||||
|  |                         <parallelTestsTimeoutForcedInSeconds>240</parallelTestsTimeoutForcedInSeconds> | ||||||
|  | 
 | ||||||
|                     </configuration> |                     </configuration> | ||||||
|                 </plugin> |                 </plugin> | ||||||
|                 <plugin> |                 <plugin> | ||||||
|  | |||||||
| @ -269,6 +269,7 @@ | |||||||
|                         <groupId>org.apache.maven.plugins</groupId> |                         <groupId>org.apache.maven.plugins</groupId> | ||||||
|                         <artifactId>maven-surefire-plugin</artifactId> |                         <artifactId>maven-surefire-plugin</artifactId> | ||||||
|                         <configuration> |                         <configuration> | ||||||
|  |                             <forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/> | ||||||
|                             <forkCount>${cpu.core.count}</forkCount> |                             <forkCount>${cpu.core.count}</forkCount> | ||||||
|                             <reuseForks>false</reuseForks> |                             <reuseForks>false</reuseForks> | ||||||
|                             <environmentVariables> |                             <environmentVariables> | ||||||
| @ -304,7 +305,7 @@ | |||||||
| 
 | 
 | ||||||
|                                 For testing large zoo models, this may not be enough (so comment it out). |                                 For testing large zoo models, this may not be enough (so comment it out). | ||||||
|                             --> |                             --> | ||||||
|                             <argLine>-Dfile.encoding=UTF-8 </argLine> |                             <argLine>-Dfile.encoding=UTF-8 -Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine> | ||||||
|                         </configuration> |                         </configuration> | ||||||
|                     </plugin> |                     </plugin> | ||||||
|                 </plugins> |                 </plugins> | ||||||
| @ -350,6 +351,7 @@ | |||||||
|                             </dependency> |                             </dependency> | ||||||
|                         </dependencies> |                         </dependencies> | ||||||
|                         <configuration> |                         <configuration> | ||||||
|  |                             <forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/> | ||||||
|                             <environmentVariables> |                             <environmentVariables> | ||||||
|                                 <LD_LIBRARY_PATH> |                                 <LD_LIBRARY_PATH> | ||||||
|                                     ${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes |                                     ${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes | ||||||
| @ -379,7 +381,12 @@ | |||||||
|                                 Maximum heap size was set to 6g, as a minimum required value for tests run. |                                 Maximum heap size was set to 6g, as a minimum required value for tests run. | ||||||
|                                 Depending on a build machine, default value is not always enough. |                                 Depending on a build machine, default value is not always enough. | ||||||
|                             --> |                             --> | ||||||
|                             <argLine> -Dfile.encoding=UTF-8  -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine> |                             <argLine>-Dfile.encoding=UTF-8 -Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine> | ||||||
|  |                             <forkedProcessTimeoutInSeconds>240</forkedProcessTimeoutInSeconds> | ||||||
|  |                             <forkedProcessExitTimeoutInSeconds>240</forkedProcessExitTimeoutInSeconds> | ||||||
|  |                             <parallelTestsTimeoutInSeconds>240</parallelTestsTimeoutInSeconds> | ||||||
|  |                             <parallelTestsTimeoutForcedInSeconds>240</parallelTestsTimeoutForcedInSeconds> | ||||||
|  | 
 | ||||||
|                         </configuration> |                         </configuration> | ||||||
|                     </plugin> |                     </plugin> | ||||||
|                 </plugins> |                 </plugins> | ||||||
|  | |||||||
| @ -216,6 +216,8 @@ public class TestSessions extends BaseNd4jTestWithBackends { | |||||||
|     @Tag(TagNames.FILE_IO) |     @Tag(TagNames.FILE_IO) | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|     public void testSwitchWhile(Nd4jBackend backend) throws Exception{ |     public void testSwitchWhile(Nd4jBackend backend) throws Exception{ | ||||||
| 
 | 
 | ||||||
|         /* |         /* | ||||||
|  | |||||||
| @ -94,7 +94,6 @@ import org.nd4j.weightinit.impl.UniformInitScheme; | |||||||
| @Tag(TagNames.SAMEDIFF) | @Tag(TagNames.SAMEDIFF) | ||||||
| public class SameDiffTests extends BaseNd4jTestWithBackends { | public class SameDiffTests extends BaseNd4jTestWithBackends { | ||||||
| 
 | 
 | ||||||
|     private DataType initialType; |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
| @ -112,16 +111,11 @@ public class SameDiffTests extends BaseNd4jTestWithBackends { | |||||||
|     @BeforeEach |     @BeforeEach | ||||||
|     public void before() { |     public void before() { | ||||||
|         Nd4j.create(1); |         Nd4j.create(1); | ||||||
|         initialType = Nd4j.dataType(); |  | ||||||
| 
 |  | ||||||
|         Nd4j.setDataType(DataType.DOUBLE); |  | ||||||
|         Nd4j.getRandom().setSeed(123); |         Nd4j.getRandom().setSeed(123); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @AfterEach |     @AfterEach | ||||||
|     public void after() { |     public void after() { | ||||||
|         Nd4j.setDataType(initialType); |  | ||||||
| 
 |  | ||||||
|         NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); |         NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); | ||||||
|         NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); |         NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); | ||||||
|     } |     } | ||||||
| @ -136,7 +130,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|         INDArray labels = Nd4j.create(new double[]{1, 1, 0, 1}).reshape(4, 1); |         INDArray labels = Nd4j.create(new double[]{1, 1, 0, 1}).reshape(4, 1); | ||||||
| 
 | 
 | ||||||
|         INDArray weights = Nd4j.zeros(3, 1); |         INDArray weights = Nd4j.zeros(3, 1).castTo(labels.dataType()); | ||||||
| 
 | 
 | ||||||
|         Map<String, INDArray> inputMap = new HashMap<>(); |         Map<String, INDArray> inputMap = new HashMap<>(); | ||||||
|         inputMap.put("x", inputs); |         inputMap.put("x", inputs); | ||||||
| @ -155,7 +149,7 @@ public class SameDiffTests extends BaseNd4jTestWithBackends { | |||||||
|         val nodeA = sd.math().square(input); |         val nodeA = sd.math().square(input); | ||||||
|         val nodeB = sd.math().square(nodeA); |         val nodeB = sd.math().square(nodeA); | ||||||
| 
 | 
 | ||||||
|         sd.associateArrayWithVariable(Nd4j.create(new double[]{1, 2, 3, 4, 5, 6}, new long[]{2, 3}), input); |         sd.associateArrayWithVariable(Nd4j.create(new double[]{1, 2, 3, 4, 5, 6}, new long[]{2, 3}).castTo(input.dataType()), input); | ||||||
| 
 | 
 | ||||||
|         sd.outputAll(null); |         sd.outputAll(null); | ||||||
| 
 | 
 | ||||||
| @ -2627,7 +2621,8 @@ public class SameDiffTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|         SameDiff sd = SameDiff.create(); |         SameDiff sd = SameDiff.create(); | ||||||
|         SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 3); |         SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 3); | ||||||
|         SDVariable w = sd.constant("w", Nd4j.rand(DataType.FLOAT, 3, 4)); |         INDArray const1 =  Nd4j.rand(DataType.FLOAT, 3, 4); | ||||||
|  |         SDVariable w = sd.constant("w",const1); | ||||||
|         SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 4)); |         SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 4)); | ||||||
|         SDVariable mmul = in.mmul(w); |         SDVariable mmul = in.mmul(w); | ||||||
|         SDVariable add = mmul.add(b); |         SDVariable add = mmul.add(b); | ||||||
|  | |||||||
| @ -21,13 +21,18 @@ | |||||||
| package org.nd4j.autodiff.samediff.listeners; | package org.nd4j.autodiff.samediff.listeners; | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | import org.junit.jupiter.api.Disabled; | ||||||
|  | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.io.TempDir; | import org.junit.jupiter.api.io.TempDir; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener; | import org.nd4j.autodiff.listeners.checkpoint.CheckpointListener; | ||||||
| import org.nd4j.autodiff.samediff.SDVariable; | import org.nd4j.autodiff.samediff.SDVariable; | ||||||
| import org.nd4j.autodiff.samediff.SameDiff; | import org.nd4j.autodiff.samediff.SameDiff; | ||||||
| import org.nd4j.autodiff.samediff.TrainingConfig; | import org.nd4j.autodiff.samediff.TrainingConfig; | ||||||
|  | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.linalg.BaseNd4jTestWithBackends; | import org.nd4j.linalg.BaseNd4jTestWithBackends; | ||||||
| import org.nd4j.linalg.api.buffer.DataType; | import org.nd4j.linalg.api.buffer.DataType; | ||||||
| import org.nd4j.linalg.dataset.IrisDataSetIterator; | import org.nd4j.linalg.dataset.IrisDataSetIterator; | ||||||
| @ -38,10 +43,7 @@ import org.nd4j.linalg.learning.config.Adam; | |||||||
| 
 | 
 | ||||||
| import java.io.File; | import java.io.File; | ||||||
| import java.nio.file.Path; | import java.nio.file.Path; | ||||||
| import java.util.Arrays; | import java.util.*; | ||||||
| import java.util.HashSet; |  | ||||||
| import java.util.List; |  | ||||||
| import java.util.Set; |  | ||||||
| import java.util.concurrent.TimeUnit; | import java.util.concurrent.TimeUnit; | ||||||
| 
 | 
 | ||||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | import static org.junit.jupiter.api.Assertions.assertEquals; | ||||||
| @ -169,8 +171,12 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Execution(ExecutionMode.SAME_THREAD) | ||||||
|  |     @Disabled("Inconsistent results on output") | ||||||
|  |     @Tag(TagNames.NEEDS_VERIFY) | ||||||
|     public void testCheckpointListenerEveryTimeUnit(Nd4jBackend backend) throws Exception { |     public void testCheckpointListenerEveryTimeUnit(Nd4jBackend backend) throws Exception { | ||||||
|         File dir = testDir.toFile(); |         File dir = testDir.resolve("new-dir-" + UUID.randomUUID().toString()).toFile(); | ||||||
|  |         assertTrue(dir.mkdirs()); | ||||||
|         SameDiff sd = getModel(); |         SameDiff sd = getModel(); | ||||||
| 
 | 
 | ||||||
|         CheckpointListener l = new CheckpointListener.Builder(dir) |         CheckpointListener l = new CheckpointListener.Builder(dir) | ||||||
| @ -181,9 +187,8 @@ public class CheckpointListenerTest extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|         DataSetIterator iter = getIter(15, 150); |         DataSetIterator iter = getIter(15, 150); | ||||||
| 
 | 
 | ||||||
|         for(int i=0; i<5; i++ ){   //10 iterations total |         for(int i = 0; i < 5; i++) {   //10 iterations total | ||||||
|             sd.fit(iter, 1); |             sd.fit(iter, 1); | ||||||
|             Thread.sleep(5000); |  | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         //Expect models saved at iterations: 10, 20, 30, 40 |         //Expect models saved at iterations: 10, 20, 30, 40 | ||||||
|  | |||||||
| @ -123,7 +123,7 @@ public class ListenerTest extends BaseNd4jTestWithBackends { | |||||||
| // | // | ||||||
| //        sd.evaluateMultiple(iter, evalMap); | //        sd.evaluateMultiple(iter, evalMap); | ||||||
| 
 | 
 | ||||||
|         e = (Evaluation) hist.finalTrainingEvaluations().evaluation(predictions); |         e = hist.finalTrainingEvaluations().evaluation(predictions); | ||||||
| 
 | 
 | ||||||
|         System.out.println(e.stats()); |         System.out.println(e.stats()); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -79,7 +79,7 @@ public class RegressionEvalTest  extends BaseNd4jTestWithBackends { | |||||||
|         RegressionEvaluation eval = new RegressionEvaluation(nCols); |         RegressionEvaluation eval = new RegressionEvaluation(nCols); | ||||||
| 
 | 
 | ||||||
|         for (int i = 0; i < nTestArrays; i++) { |         for (int i = 0; i < nTestArrays; i++) { | ||||||
|             INDArray rand = Nd4j.rand(valuesPerTestArray, nCols).castTo(DataType.DOUBLE); |             INDArray rand = Nd4j.rand(DataType.DOUBLE,valuesPerTestArray, nCols); | ||||||
|             eval.eval(rand, rand); |             eval.eval(rand, rand); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
| @ -172,8 +172,8 @@ public class RegressionEvalTest  extends BaseNd4jTestWithBackends { | |||||||
|         for (int i = 0; i < nEvalInstances; i++) { |         for (int i = 0; i < nEvalInstances; i++) { | ||||||
|             list.add(new RegressionEvaluation(nCols)); |             list.add(new RegressionEvaluation(nCols)); | ||||||
|             for (int j = 0; j < numMinibatches; j++) { |             for (int j = 0; j < numMinibatches; j++) { | ||||||
|                 INDArray p = Nd4j.rand(nRows, nCols).castTo(Nd4j.defaultFloatingPointType()); |                 INDArray p = Nd4j.rand(DataType.DOUBLE,nRows, nCols); | ||||||
|                 INDArray act = Nd4j.rand(nRows, nCols).castTo(Nd4j.defaultFloatingPointType()); |                 INDArray act = Nd4j.rand(DataType.DOUBLE,nRows, nCols); | ||||||
| 
 | 
 | ||||||
|                 single.eval(act, p); |                 single.eval(act, p); | ||||||
| 
 | 
 | ||||||
| @ -383,7 +383,7 @@ public class RegressionEvalTest  extends BaseNd4jTestWithBackends { | |||||||
|         List<INDArray> rowsL = new ArrayList<>(); |         List<INDArray> rowsL = new ArrayList<>(); | ||||||
| 
 | 
 | ||||||
|         //Check per-example masking: |         //Check per-example masking: | ||||||
|         INDArray mask1dPerEx = Nd4j.createFromArray(1, 0); |         INDArray mask1dPerEx = Nd4j.createFromArray(1, 0).castTo(DataType.FLOAT); | ||||||
| 
 | 
 | ||||||
|         NdIndexIterator iter = new NdIndexIterator(2, 10, 10); |         NdIndexIterator iter = new NdIndexIterator(2, 10, 10); | ||||||
|         while (iter.hasNext()) { |         while (iter.hasNext()) { | ||||||
| @ -409,7 +409,7 @@ public class RegressionEvalTest  extends BaseNd4jTestWithBackends { | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         //Check per-output masking: |         //Check per-output masking: | ||||||
|         INDArray perOutMask = Nd4j.randomBernoulli(0.5, label.shape()); |         INDArray perOutMask = Nd4j.randomBernoulli(0.5, label.shape()).castTo(DataType.FLOAT); | ||||||
|         rowsP.clear(); |         rowsP.clear(); | ||||||
|         rowsL.clear(); |         rowsL.clear(); | ||||||
|         List<INDArray> rowsM = new ArrayList<>(); |         List<INDArray> rowsM = new ArrayList<>(); | ||||||
|  | |||||||
| @ -24,6 +24,8 @@ import lombok.extern.slf4j.Slf4j; | |||||||
| import org.junit.jupiter.api.AfterEach; | import org.junit.jupiter.api.AfterEach; | ||||||
| import org.junit.jupiter.api.BeforeEach; | import org.junit.jupiter.api.BeforeEach; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| 
 | 
 | ||||||
| @ -45,18 +47,15 @@ public class AveragingTests extends BaseNd4jTestWithBackends { | |||||||
|     private final int THREADS = 16; |     private final int THREADS = 16; | ||||||
|     private final int LENGTH = 51200 * 4; |     private final int LENGTH = 51200 * 4; | ||||||
| 
 | 
 | ||||||
|     DataType initialType = Nd4j.dataType(); |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     @BeforeEach |     @BeforeEach | ||||||
|     public void setUp() { |     public void setUp() { | ||||||
|         DataTypeUtil.setDTypeForContext(DataType.DOUBLE); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @AfterEach |     @AfterEach | ||||||
|     public void shutUp() { |     public void shutUp() { | ||||||
|         DataTypeUtil.setDTypeForContext(initialType); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -111,6 +110,7 @@ public class AveragingTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Execution(ExecutionMode.SAME_THREAD) | ||||||
|     public void testSingleDeviceAveraging2(Nd4jBackend backend) { |     public void testSingleDeviceAveraging2(Nd4jBackend backend) { | ||||||
|         INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH); |         INDArray exp = Nd4j.linspace(1, LENGTH, LENGTH); | ||||||
|         List<INDArray> arrays = new ArrayList<>(); |         List<INDArray> arrays = new ArrayList<>(); | ||||||
|  | |||||||
| @ -23,11 +23,13 @@ package org.nd4j.linalg; | |||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import lombok.val; | import lombok.val; | ||||||
| import org.apache.commons.lang3.RandomUtils; | import org.apache.commons.lang3.RandomUtils; | ||||||
|  | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| 
 | 
 | ||||||
| import org.nd4j.common.tests.tags.NativeTag; | import org.nd4j.common.tests.tags.NativeTag; | ||||||
|  | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.linalg.api.buffer.DataType; | import org.nd4j.linalg.api.buffer.DataType; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; | import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; | ||||||
| @ -239,6 +241,8 @@ public class LoneTest extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|  |     @Tag(TagNames.LONG_TEST) | ||||||
|     public void testGetRow1(Nd4jBackend backend) { |     public void testGetRow1(Nd4jBackend backend) { | ||||||
|         INDArray array = Nd4j.create(10000, 10000); |         INDArray array = Nd4j.create(10000, 10000); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -29,6 +29,9 @@ import org.apache.commons.math3.util.FastMath; | |||||||
| import org.junit.jupiter.api.*; | import org.junit.jupiter.api.*; | ||||||
| 
 | 
 | ||||||
| import org.junit.jupiter.api.io.TempDir; | import org.junit.jupiter.api.io.TempDir; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
|  | import org.junit.jupiter.api.parallel.Isolated; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| 
 | 
 | ||||||
| @ -148,8 +151,6 @@ import static org.junit.jupiter.api.Assertions.*; | |||||||
| @Tag(TagNames.FILE_IO) | @Tag(TagNames.FILE_IO) | ||||||
| public class Nd4jTestsC extends BaseNd4jTestWithBackends { | public class Nd4jTestsC extends BaseNd4jTestWithBackends { | ||||||
| 
 | 
 | ||||||
|     DataType initialType = Nd4j.dataType(); |  | ||||||
|     Level1 l1 = Nd4j.getBlasWrapper().level1(); |  | ||||||
|     @TempDir Path testDir; |     @TempDir Path testDir; | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
| @ -159,7 +160,6 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @BeforeEach |     @BeforeEach | ||||||
|     public void before() throws Exception { |     public void before() throws Exception { | ||||||
|         Nd4j.setDataType(DataType.DOUBLE); |  | ||||||
|         Nd4j.getRandom().setSeed(123); |         Nd4j.getRandom().setSeed(123); | ||||||
|         Nd4j.getExecutioner().enableDebugMode(false); |         Nd4j.getExecutioner().enableDebugMode(false); | ||||||
|         Nd4j.getExecutioner().enableVerboseMode(false); |         Nd4j.getExecutioner().enableVerboseMode(false); | ||||||
| @ -167,7 +167,6 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @AfterEach |     @AfterEach | ||||||
|     public void after() throws Exception { |     public void after() throws Exception { | ||||||
|         Nd4j.setDataType(initialType); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
| @ -1480,7 +1479,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}, new int[]{12}); |         INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}, new int[]{12}); | ||||||
|         INDArray flattened = Nd4j.toFlattened(concat); |         INDArray flattened = Nd4j.toFlattened(concat).castTo(assertion.dataType()); | ||||||
|         assertEquals(assertion, flattened); |         assertEquals(assertion, flattened); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -3902,6 +3901,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Disabled("Crashes") | ||||||
|  |     @Tag(TagNames.NEEDS_VERIFY) | ||||||
|     public void testSingleDeviceAveraging(Nd4jBackend backend) { |     public void testSingleDeviceAveraging(Nd4jBackend backend) { | ||||||
|         int LENGTH = 512 * 1024 * 2; |         int LENGTH = 512 * 1024 * 2; | ||||||
|         INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0); |         INDArray array1 = Nd4j.valueArrayOf(LENGTH, 1.0); | ||||||
| @ -5587,6 +5588,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Disabled("Crashes") | ||||||
|  |     @Tag(TagNames.NEEDS_VERIFY) | ||||||
|     public void testNativeSort3(Nd4jBackend backend) { |     public void testNativeSort3(Nd4jBackend backend) { | ||||||
|         int length = isIntegrationTests() ? 1048576 : 16484; |         int length = isIntegrationTests() ? 1048576 : 16484; | ||||||
|         INDArray array = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape(1, -1); |         INDArray array = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape(1, -1); | ||||||
| @ -5719,6 +5722,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Disabled("Crashes") | ||||||
|  |     @Tag(TagNames.NEEDS_VERIFY) | ||||||
|     public void testNativeSortAlongDimension1(Nd4jBackend backend) { |     public void testNativeSortAlongDimension1(Nd4jBackend backend) { | ||||||
|         INDArray array = Nd4j.create(1000, 1000); |         INDArray array = Nd4j.create(1000, 1000); | ||||||
|         INDArray exp1 = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE); |         INDArray exp1 = Nd4j.linspace(1, 1000, 1000, DataType.DOUBLE); | ||||||
| @ -5779,6 +5784,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Disabled("Crashes") | ||||||
|  |     @Tag(TagNames.NEEDS_VERIFY) | ||||||
|     public void testNativeSortAlongDimension3(Nd4jBackend backend) { |     public void testNativeSortAlongDimension3(Nd4jBackend backend) { | ||||||
|         INDArray array = Nd4j.create(2000,  2000); |         INDArray array = Nd4j.create(2000,  2000); | ||||||
|         INDArray exp1 = Nd4j.linspace(1, 2000, 2000, DataType.DOUBLE); |         INDArray exp1 = Nd4j.linspace(1, 2000, 2000, DataType.DOUBLE); | ||||||
| @ -5814,6 +5821,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Disabled("Crashes") | ||||||
|  |     @Tag(TagNames.NEEDS_VERIFY) | ||||||
|     public void testNativeSortAlongDimension2(Nd4jBackend backend) { |     public void testNativeSortAlongDimension2(Nd4jBackend backend) { | ||||||
|         INDArray array = Nd4j.create(100, 10); |         INDArray array = Nd4j.create(100, 10); | ||||||
|         INDArray exp1 = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); |         INDArray exp1 = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); | ||||||
| @ -6768,15 +6777,16 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testInconsistentOutput(){ |     @Execution(ExecutionMode.SAME_THREAD) | ||||||
|  |     public void testInconsistentOutput(Nd4jBackend backend) { | ||||||
|         INDArray in = Nd4j.rand(1, 802816).castTo(DataType.DOUBLE); |         INDArray in = Nd4j.rand(1, 802816).castTo(DataType.DOUBLE); | ||||||
|         INDArray W = Nd4j.rand(802816, 1).castTo(DataType.DOUBLE); |         INDArray W = Nd4j.rand(802816, 1).castTo(DataType.DOUBLE); | ||||||
|         INDArray b = Nd4j.create(1).castTo(DataType.DOUBLE); |         INDArray b = Nd4j.create(1).castTo(DataType.DOUBLE); | ||||||
|         INDArray out = fwd(in, W, b); |         INDArray out = fwd(in, W, b); | ||||||
| 
 | 
 | ||||||
|         for(int i = 0;i < 100;i++) { |         for(int i = 0; i < 100;i++) { | ||||||
|             INDArray out2 = fwd(in, W, b);  //l.activate(inToLayer1, false, LayerWorkspaceMgr.noWorkspaces()); |             INDArray out2 = fwd(in, W, b);  //l.activate(inToLayer1, false, LayerWorkspaceMgr.noWorkspaces()); | ||||||
|             assertEquals( out, out2,"Failed at iteration [" + String.valueOf(i) + "]"); |             assertEquals( out, out2,"Failed at iteration [" + i + "]"); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -7144,9 +7154,10 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testRowColumnOpsRank1(){ |     @Execution(ExecutionMode.SAME_THREAD) | ||||||
|  |     public void testRowColumnOpsRank1(Nd4jBackend backend) { | ||||||
| 
 | 
 | ||||||
|         for( int i=0; i<6; i++ ) { |         for( int i = 0; i < 6; i++ ) { | ||||||
|             INDArray orig = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); |             INDArray orig = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); | ||||||
|             INDArray in1r = orig.dup(); |             INDArray in1r = orig.dup(); | ||||||
|             INDArray in2r = orig.dup(); |             INDArray in2r = orig.dup(); | ||||||
| @ -7954,6 +7965,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Disabled("Crashes") | ||||||
|  |     @Tag(TagNames.NEEDS_VERIFY) | ||||||
|     public void testRollingMean(Nd4jBackend backend) { |     public void testRollingMean(Nd4jBackend backend) { | ||||||
|         val wsconf = WorkspaceConfiguration.builder() |         val wsconf = WorkspaceConfiguration.builder() | ||||||
|                 .initialSize(4L * (32*128*256*256 + 32*128 + 10*1024*1024)) |                 .initialSize(4L * (32*128*256*256 + 32*128 + 10*1024*1024)) | ||||||
| @ -8558,8 +8571,6 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     @Disabled("Needs verification") |  | ||||||
|     @Tag(TagNames.NEEDS_VERIFY) |  | ||||||
|     public void testBatchToSpace(Nd4jBackend backend) { |     public void testBatchToSpace(Nd4jBackend backend) { | ||||||
| 
 | 
 | ||||||
|         INDArray out = Nd4j.create(DataType.FLOAT, 2, 4, 5); |         INDArray out = Nd4j.create(DataType.FLOAT, 2, 4, 5); | ||||||
| @ -8833,7 +8844,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testCreateBufferFromByteBuffer(){ |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|  |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     public void testCreateBufferFromByteBuffer(Nd4jBackend backend){ | ||||||
| 
 | 
 | ||||||
|         for(DataType dt : DataType.values()){ |         for(DataType dt : DataType.values()){ | ||||||
|             if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN) |             if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN) | ||||||
|  | |||||||
| @ -41,6 +41,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; | |||||||
| import org.nd4j.nativeblas.NativeOpsHolder; | import org.nd4j.nativeblas.NativeOpsHolder; | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | import java.nio.Buffer; | ||||||
| import java.nio.ByteBuffer; | import java.nio.ByteBuffer; | ||||||
| import java.nio.ByteOrder; | import java.nio.ByteOrder; | ||||||
| 
 | 
 | ||||||
| @ -378,7 +379,8 @@ public class DataBufferTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|             INDArray arr2 = Nd4j.create(dt, arr.shape()); |             INDArray arr2 = Nd4j.create(dt, arr.shape()); | ||||||
|             ByteBuffer bb = arr2.data().pointer().asByteBuffer(); |             ByteBuffer bb = arr2.data().pointer().asByteBuffer(); | ||||||
|             bb.position(0); |             Buffer buffer = (Buffer) bb; | ||||||
|  |             buffer.position(0); | ||||||
|             bb.put(b); |             bb.put(b); | ||||||
| 
 | 
 | ||||||
|             Nd4j.getAffinityManager().tagLocation(arr2, AffinityManager.Location.HOST); |             Nd4j.getAffinityManager().tagLocation(arr2, AffinityManager.Location.HOST); | ||||||
|  | |||||||
| @ -59,18 +59,15 @@ import static org.junit.jupiter.api.Assertions.*; | |||||||
| @NativeTag | @NativeTag | ||||||
| public class FloatDataBufferTest extends BaseNd4jTestWithBackends { | public class FloatDataBufferTest extends BaseNd4jTestWithBackends { | ||||||
| 
 | 
 | ||||||
|     DataType initialType = Nd4j.dataType(); |  | ||||||
|     @TempDir Path tempDir; |     @TempDir Path tempDir; | ||||||
| 
 | 
 | ||||||
|     @BeforeEach |     @BeforeEach | ||||||
|     public void before() { |     public void before() { | ||||||
|         DataTypeUtil.setDTypeForContext(DataType.FLOAT); |  | ||||||
|         System.out.println("DATATYPE HERE: " + Nd4j.dataType()); |         System.out.println("DATATYPE HERE: " + Nd4j.dataType()); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @AfterEach |     @AfterEach | ||||||
|     public void after() { |     public void after() { | ||||||
|         DataTypeUtil.setDTypeForContext(initialType); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -191,7 +188,7 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends { | |||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testAsBytes(Nd4jBackend backend) { |     public void testAsBytes(Nd4jBackend backend) { | ||||||
|         INDArray arr = Nd4j.create(5); |         INDArray arr = Nd4j.create(DataType.FLOAT,5); | ||||||
|         byte[] d = arr.data().asBytes(); |         byte[] d = arr.data().asBytes(); | ||||||
|         assertEquals(4 * 5, d.length,getFailureMessage(backend)); |         assertEquals(4 * 5, d.length,getFailureMessage(backend)); | ||||||
|         INDArray rand = Nd4j.rand(3, 3); |         INDArray rand = Nd4j.rand(3, 3); | ||||||
| @ -245,7 +242,9 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends { | |||||||
|         buffer.reallocate(6); |         buffer.reallocate(6); | ||||||
|         float[] newBuf = buffer.asFloat(); |         float[] newBuf = buffer.asFloat(); | ||||||
|         assertEquals(6, buffer.capacity()); |         assertEquals(6, buffer.capacity()); | ||||||
|         assertArrayEquals(old, newBuf, 1e-4F); |         //note: old and new buf are not equal because java automatically populates the arrays with zeros | ||||||
|  |         //the new buffer is actually 1,2,3,4,0,0 because of this | ||||||
|  |         assertArrayEquals(new float[]{1,2,3,4,0,0}, newBuf, 1e-4F); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
| @ -253,17 +252,17 @@ public class FloatDataBufferTest extends BaseNd4jTestWithBackends { | |||||||
|     public void testReallocationWorkspace(Nd4jBackend backend) { |     public void testReallocationWorkspace(Nd4jBackend backend) { | ||||||
|         WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) |         WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L) | ||||||
|                         .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); |                         .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build(); | ||||||
|         MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID"); |         try(MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID")) { | ||||||
| 
 |             DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4}); | ||||||
|         DataBuffer buffer = Nd4j.createBuffer(new float[] {1, 2, 3, 4}); |             assertTrue(buffer.isAttached()); | ||||||
|         assertTrue(buffer.isAttached()); |             float[] old = buffer.asFloat(); | ||||||
|         float[] old = buffer.asFloat(); |             assertEquals(4, buffer.capacity()); | ||||||
|         assertEquals(4, buffer.capacity()); |             buffer.reallocate(6); | ||||||
|         buffer.reallocate(6); |             assertEquals(6, buffer.capacity()); | ||||||
|         assertEquals(6, buffer.capacity()); |             float[] newBuf = buffer.asFloat(); | ||||||
|         float[] newBuf = buffer.asFloat(); |             //note: java creates new zeros by default for empty array spots | ||||||
|         assertArrayEquals(old, newBuf, 1e-4F); |             assertArrayEquals(new float[]{1,2,3,4,0,0}, newBuf, 1e-4F); | ||||||
|         workspace.close(); |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|  | |||||||
| @ -175,9 +175,11 @@ public class BasicBroadcastTests extends BaseNd4jTestWithBackends { | |||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void basicBroadcastFailureTest_4(Nd4jBackend backend) { |     public void basicBroadcastFailureTest_4(Nd4jBackend backend) { | ||||||
|         val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); |         assertThrows(IllegalStateException.class,() -> { | ||||||
|         val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); |             val x = Nd4j.create(DataType.FLOAT, 3, 1, 2).assign(4.f); | ||||||
|         val z = x.addi(y); |             val y = Nd4j.createFromArray(new float[]{2.f, 2.f, 2.f, 2.f}).reshape(2, 2); | ||||||
|  |             val z = x.addi(y); | ||||||
|  |         }); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|  | |||||||
| @ -25,6 +25,9 @@ import lombok.val; | |||||||
| import org.junit.jupiter.api.Disabled; | import org.junit.jupiter.api.Disabled; | ||||||
| import org.junit.jupiter.api.Tag; | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
|  | import org.junit.jupiter.api.parallel.Isolated; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| 
 | 
 | ||||||
| @ -52,6 +55,8 @@ import static org.junit.jupiter.api.Assertions.*; | |||||||
| @Slf4j | @Slf4j | ||||||
| @NativeTag | @NativeTag | ||||||
| @Tag(TagNames.COMPRESSION) | @Tag(TagNames.COMPRESSION) | ||||||
|  | @Isolated | ||||||
|  | @Execution(ExecutionMode.SAME_THREAD) | ||||||
| public class CompressionTests extends BaseNd4jTestWithBackends { | public class CompressionTests extends BaseNd4jTestWithBackends { | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -412,9 +417,11 @@ public class CompressionTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|     public void testBitmapEncoding2(Nd4jBackend backend) { |     public void testBitmapEncoding2(Nd4jBackend backend) { | ||||||
|         INDArray initial = Nd4j.create(40000000); |         INDArray initial = Nd4j.create(DataType.FLOAT,40000000); | ||||||
|         INDArray target = Nd4j.create(initial.length()); |         INDArray target = Nd4j.create(DataType.FLOAT,initial.length()); | ||||||
| 
 | 
 | ||||||
|         initial.addi(1e-3); |         initial.addi(1e-3); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -61,6 +61,7 @@ public class DeconvTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|     public void compareKeras(Nd4jBackend backend) throws Exception { |     public void compareKeras(Nd4jBackend backend) throws Exception { | ||||||
|         File newFolder = testDir.toFile(); |         File newFolder = testDir.toFile(); | ||||||
|         new ClassPathResource("keras/deconv/").copyDirectory(newFolder); |         new ClassPathResource("keras/deconv/").copyDirectory(newFolder); | ||||||
|  | |||||||
| @ -103,7 +103,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public char ordering(){ |     public char ordering() { | ||||||
|         return 'c'; |         return 'c'; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -566,7 +566,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testStridedSliceEdgeCase(){ |     public void testStridedSliceEdgeCase(Nd4jBackend backend) { | ||||||
|         INDArray in = Nd4j.scalar(10.0).reshape(1);   //Int [1] |         INDArray in = Nd4j.scalar(10.0).reshape(1);   //Int [1] | ||||||
|         INDArray begin = Nd4j.ones(DataType.INT, 1); |         INDArray begin = Nd4j.ones(DataType.INT, 1); | ||||||
|         INDArray end = Nd4j.zeros(DataType.INT, 1); |         INDArray end = Nd4j.zeros(DataType.INT, 1); | ||||||
| @ -595,7 +595,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testDepthwise(){ |     public void testDepthwise(Nd4jBackend backend) { | ||||||
|         INDArray input = Nd4j.create(DataType.DOUBLE, 1,3,8,8); |         INDArray input = Nd4j.create(DataType.DOUBLE, 1,3,8,8); | ||||||
|         INDArray depthwiseWeight = Nd4j.create(DataType.DOUBLE, 1,1,3,2); |         INDArray depthwiseWeight = Nd4j.create(DataType.DOUBLE, 1,1,3,2); | ||||||
|         INDArray bias = Nd4j.create(DataType.DOUBLE, 1, 6); |         INDArray bias = Nd4j.create(DataType.DOUBLE, 1, 6); | ||||||
| @ -660,8 +660,10 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
|         assertEquals(e, z); |         assertEquals(e, z); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Test() | 
 | ||||||
|     public void testInputValidationMergeMax(){ |     @ParameterizedTest | ||||||
|  |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     public void testInputValidationMergeMax(Nd4jBackend backend) { | ||||||
|         assertThrows(RuntimeException.class,() -> { |         assertThrows(RuntimeException.class,() -> { | ||||||
|             INDArray[] inputs = new INDArray[]{ |             INDArray[] inputs = new INDArray[]{ | ||||||
|                     Nd4j.createFromArray(0.0f, 1.0f, 2.0f).reshape('c', 1, 3), |                     Nd4j.createFromArray(0.0f, 1.0f, 2.0f).reshape('c', 1, 3), | ||||||
| @ -683,7 +685,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testUpsampling2dBackprop(){ |     public void testUpsampling2dBackprop(Nd4jBackend backend) { | ||||||
| 
 | 
 | ||||||
|         Nd4j.getRandom().setSeed(12345); |         Nd4j.getRandom().setSeed(12345); | ||||||
|         int c = 2; |         int c = 2; | ||||||
| @ -729,7 +731,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testIsMaxView(){ |     public void testIsMaxView(Nd4jBackend backend) { | ||||||
|         INDArray predictions = Nd4j.rand(DataType.FLOAT, 3, 4, 3, 2); |         INDArray predictions = Nd4j.rand(DataType.FLOAT, 3, 4, 3, 2); | ||||||
| 
 | 
 | ||||||
|         INDArray row = predictions.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.point(0)); |         INDArray row = predictions.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.point(0)); | ||||||
| @ -748,7 +750,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void isMax4d_2dims(){ |     public void isMax4d_2dims(Nd4jBackend backend) { | ||||||
|         Nd4j.getRandom().setSeed(12345); |         Nd4j.getRandom().setSeed(12345); | ||||||
|         INDArray in = Nd4j.rand(DataType.FLOAT, 3, 3, 4, 4).permute(0, 2, 3, 1); |         INDArray in = Nd4j.rand(DataType.FLOAT, 3, 3, 4, 4).permute(0, 2, 3, 1); | ||||||
| 
 | 
 | ||||||
| @ -764,7 +766,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testSizeTypes(){ |     public void testSizeTypes(Nd4jBackend backend) { | ||||||
|         List<DataType> failed = new ArrayList<>(); |         List<DataType> failed = new ArrayList<>(); | ||||||
|         for(DataType dt : new DataType[]{DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE, |         for(DataType dt : new DataType[]{DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE, | ||||||
|                 DataType.UINT64, DataType.UINT32, DataType.UINT16, DataType.UBYTE, |                 DataType.UINT64, DataType.UINT32, DataType.UINT16, DataType.UBYTE, | ||||||
| @ -796,7 +798,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testListDiff(){ |     public void testListDiff(Nd4jBackend backend) { | ||||||
|         INDArray x = Nd4j.createFromArray(0, 1, 2, 3); |         INDArray x = Nd4j.createFromArray(0, 1, 2, 3); | ||||||
|         INDArray y = Nd4j.createFromArray(3, 1); |         INDArray y = Nd4j.createFromArray(3, 1); | ||||||
| 
 | 
 | ||||||
| @ -817,7 +819,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testTopK1(){ |     public void testTopK1(Nd4jBackend backend) { | ||||||
|         INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0); |         INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0); | ||||||
|         INDArray k = Nd4j.scalar(1); |         INDArray k = Nd4j.scalar(1); | ||||||
|         INDArray outValue = Nd4j.create(DataType.DOUBLE, 1); |         INDArray outValue = Nd4j.create(DataType.DOUBLE, 1); | ||||||
| @ -897,7 +899,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testAdjustContrastShape(){ |     public void testAdjustContrastShape(Nd4jBackend backend) { | ||||||
|         DynamicCustomOp op = DynamicCustomOp.builder("adjust_contrast_v2") |         DynamicCustomOp op = DynamicCustomOp.builder("adjust_contrast_v2") | ||||||
|                 .addInputs(Nd4j.create(DataType.FLOAT, 256, 256,3), Nd4j.scalar(0.5f)) |                 .addInputs(Nd4j.create(DataType.FLOAT, 256, 256,3), Nd4j.scalar(0.5f)) | ||||||
|                 .build(); |                 .build(); | ||||||
| @ -910,7 +912,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testBitCastShape(){ |     public void testBitCastShape(Nd4jBackend backend) { | ||||||
|         INDArray out = Nd4j.createUninitialized(1,10); |         INDArray out = Nd4j.createUninitialized(1,10); | ||||||
|         BitCast op = new BitCast(Nd4j.zeros(1,10), DataType.FLOAT.toInt(), out); |         BitCast op = new BitCast(Nd4j.zeros(1,10), DataType.FLOAT.toInt(), out); | ||||||
|         List<LongShapeDescriptor> lsd = op.calculateOutputShape(); |         List<LongShapeDescriptor> lsd = op.calculateOutputShape(); | ||||||
| @ -1148,7 +1150,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testRange(){ |     public void testRange(Nd4jBackend backend) { | ||||||
|         DynamicCustomOp op = DynamicCustomOp.builder("range") |         DynamicCustomOp op = DynamicCustomOp.builder("range") | ||||||
|                 .addFloatingPointArguments(-1.0, 1.0, 0.01) |                 .addFloatingPointArguments(-1.0, 1.0, 0.01) | ||||||
|                 .build(); |                 .build(); | ||||||
| @ -1163,7 +1165,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testBitCastShape_1(){ |     public void testBitCastShape_1(Nd4jBackend backend) { | ||||||
|         val out = Nd4j.createUninitialized(1,10); |         val out = Nd4j.createUninitialized(1,10); | ||||||
|         BitCast op = new BitCast(Nd4j.zeros(DataType.FLOAT,1,10), DataType.INT.toInt(), out); |         BitCast op = new BitCast(Nd4j.zeros(DataType.FLOAT,1,10), DataType.INT.toInt(), out); | ||||||
|         List<LongShapeDescriptor> lsd = op.calculateOutputShape(); |         List<LongShapeDescriptor> lsd = op.calculateOutputShape(); | ||||||
| @ -1174,7 +1176,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testBitCastShape_2(){ |     public void testBitCastShape_2(Nd4jBackend backend) { | ||||||
|         val out = Nd4j.createUninitialized(1,10); |         val out = Nd4j.createUninitialized(1,10); | ||||||
|         BitCast op = new BitCast(Nd4j.zeros(DataType.DOUBLE,1,10), DataType.INT.toInt(), out); |         BitCast op = new BitCast(Nd4j.zeros(DataType.DOUBLE,1,10), DataType.INT.toInt(), out); | ||||||
|         List<LongShapeDescriptor> lsd = op.calculateOutputShape(); |         List<LongShapeDescriptor> lsd = op.calculateOutputShape(); | ||||||
| @ -1283,8 +1285,6 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     @Tag(TagNames.NEEDS_VERIFY) |  | ||||||
|     @Disabled("Implementation needs verification") |  | ||||||
|     public void testPolygamma(Nd4jBackend backend) { |     public void testPolygamma(Nd4jBackend backend) { | ||||||
|         INDArray n = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3,3); |         INDArray n = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 9).reshape(3,3); | ||||||
|         INDArray x = Nd4j.create(DataType.DOUBLE, 3,3); |         INDArray x = Nd4j.create(DataType.DOUBLE, 3,3); | ||||||
| @ -1292,7 +1292,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
|         INDArray expected = Nd4j.createFromArray(new double[]{4.934802, -16.828796, 97.409088, -771.474243, |         INDArray expected = Nd4j.createFromArray(new double[]{4.934802, -16.828796, 97.409088, -771.474243, | ||||||
|                 7691.113770f, -92203.460938f, 1290440.250000, -20644900.000000, 3.71595e+08}).reshape(3,3); |                 7691.113770f, -92203.460938f, 1290440.250000, -20644900.000000, 3.71595e+08}).reshape(3,3); | ||||||
|         INDArray output = Nd4j.create(DataType.DOUBLE, expected.shape()); |         INDArray output = Nd4j.create(DataType.DOUBLE, expected.shape()); | ||||||
|         val op = new Polygamma(x,n,output); |         val op = new Polygamma(n,x,output); | ||||||
|         Nd4j.exec(op); |         Nd4j.exec(op); | ||||||
|         assertEquals(expected, output); |         assertEquals(expected, output); | ||||||
|     } |     } | ||||||
| @ -1424,7 +1424,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testAdjustHueShape(){ |     public void testAdjustHueShape(Nd4jBackend backend) { | ||||||
|         INDArray image = Nd4j.createFromArray(new float[]{0.7788f,    0.8012f,    0.7244f, |         INDArray image = Nd4j.createFromArray(new float[]{0.7788f,    0.8012f,    0.7244f, | ||||||
|                 0.2309f,    0.7271f,    0.1804f, 0.5056f,    0.8925f,    0.5461f, |                 0.2309f,    0.7271f,    0.1804f, 0.5056f,    0.8925f,    0.5461f, | ||||||
|                 0.9234f,    0.0856f,    0.7938f, 0.6591f,    0.5555f,    0.1596f, |                 0.9234f,    0.0856f,    0.7938f, 0.6591f,    0.5555f,    0.1596f, | ||||||
| @ -1470,7 +1470,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testBitCastShape_3(){ |     public void testBitCastShape_3(Nd4jBackend backend) { | ||||||
|         val x = Nd4j.createFromArray(new int[]{1, 2, 3, 4, 5, 6, 7, 8}).reshape(1, 4, 2); |         val x = Nd4j.createFromArray(new int[]{1, 2, 3, 4, 5, 6, 7, 8}).reshape(1, 4, 2); | ||||||
|         val e = Nd4j.createFromArray(new long[]{8589934593L, 17179869187L, 25769803781L, 34359738375L}).reshape(1, 4); |         val e = Nd4j.createFromArray(new long[]{8589934593L, 17179869187L, 25769803781L, 34359738375L}).reshape(1, 4); | ||||||
|         val z = Nd4j.exec(new BitCast(x, DataType.LONG.toInt()))[0]; |         val z = Nd4j.exec(new BitCast(x, DataType.LONG.toInt()))[0]; | ||||||
| @ -1958,7 +1958,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testBatchNormBpNHWC(){ |     public void testBatchNormBpNHWC(Nd4jBackend backend) { | ||||||
|         //Nd4j.getEnvironment().allowHelpers(false);        //Passes if helpers/MKLDNN is disabled |         //Nd4j.getEnvironment().allowHelpers(false);        //Passes if helpers/MKLDNN is disabled | ||||||
| 
 | 
 | ||||||
|         INDArray in = Nd4j.rand(DataType.FLOAT, 2, 4, 4, 3); |         INDArray in = Nd4j.rand(DataType.FLOAT, 2, 4, 4, 3); | ||||||
| @ -1971,13 +1971,13 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|         assertEquals(eps, epsStrided); |         assertEquals(eps, epsStrided); | ||||||
| 
 | 
 | ||||||
|         INDArray out1eps = in.like(); |         INDArray out1eps = in.like().castTo(DataType.FLOAT); | ||||||
|         INDArray out1m = mean.like(); |         INDArray out1m = mean.like().castTo(DataType.FLOAT); | ||||||
|         INDArray out1v = var.like(); |         INDArray out1v = var.like().castTo(DataType.FLOAT); | ||||||
| 
 | 
 | ||||||
|         INDArray out2eps = in.like(); |         INDArray out2eps = in.like().castTo(DataType.FLOAT); | ||||||
|         INDArray out2m = mean.like(); |         INDArray out2m = mean.like().castTo(DataType.FLOAT); | ||||||
|         INDArray out2v = var.like(); |         INDArray out2v = var.like().castTo(DataType.FLOAT); | ||||||
| 
 | 
 | ||||||
|         DynamicCustomOp op1 = DynamicCustomOp.builder("batchnorm_bp") |         DynamicCustomOp op1 = DynamicCustomOp.builder("batchnorm_bp") | ||||||
|                 .addInputs(in, mean, var, gamma, beta, eps) |                 .addInputs(in, mean, var, gamma, beta, eps) | ||||||
| @ -2004,7 +2004,7 @@ public class CustomOpsTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testSpaceToDepthBadStrides(){ |     public void testSpaceToDepthBadStrides(Nd4jBackend backend) { | ||||||
|         INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3, 6, 6); |         INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3, 6, 6); | ||||||
|         INDArray inBadStrides = in.permute(1,0,2,3).dup().permute(1,0,2,3); |         INDArray inBadStrides = in.permute(1,0,2,3).dup().permute(1,0,2,3); | ||||||
|         assertEquals(in, inBadStrides); |         assertEquals(in, inBadStrides); | ||||||
|  | |||||||
| @ -24,6 +24,7 @@ import lombok.Getter; | |||||||
| import org.junit.jupiter.api.BeforeEach; | import org.junit.jupiter.api.BeforeEach; | ||||||
| import org.junit.jupiter.api.Tag; | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.junit.jupiter.api.io.TempDir; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| 
 | 
 | ||||||
| @ -45,10 +46,13 @@ import org.nd4j.linalg.dataset.api.preprocessor.stats.DistributionStats; | |||||||
| import org.nd4j.linalg.dataset.api.preprocessor.stats.MinMaxStats; | import org.nd4j.linalg.dataset.api.preprocessor.stats.MinMaxStats; | ||||||
| import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats; | import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
|  | import org.nd4j.linalg.factory.Nd4jBackend; | ||||||
| 
 | 
 | ||||||
| import java.io.*; | import java.io.*; | ||||||
|  | import java.nio.file.Files; | ||||||
| import java.util.HashMap; | import java.util.HashMap; | ||||||
| import java.util.Map; | import java.util.Map; | ||||||
|  | import java.util.UUID; | ||||||
| 
 | 
 | ||||||
| import static java.util.Arrays.asList; | import static java.util.Arrays.asList; | ||||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | import static org.junit.jupiter.api.Assertions.assertEquals; | ||||||
| @ -61,83 +65,91 @@ import static org.junit.jupiter.api.Assertions.assertThrows; | |||||||
| @NativeTag | @NativeTag | ||||||
| @Tag(TagNames.FILE_IO) | @Tag(TagNames.FILE_IO) | ||||||
| public class NormalizerSerializerTest extends BaseNd4jTestWithBackends { | public class NormalizerSerializerTest extends BaseNd4jTestWithBackends { | ||||||
|     private File tmpFile; |      @TempDir  File tmpFile; | ||||||
|     private NormalizerSerializer SUT; |     private NormalizerSerializer SUT; | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     @BeforeEach |     @BeforeEach | ||||||
|     public void setUp() throws IOException { |     public void setUp() throws IOException { | ||||||
|         tmpFile = File.createTempFile("test", "preProcessor"); |  | ||||||
|         tmpFile.deleteOnExit(); |  | ||||||
| 
 |  | ||||||
|         SUT = NormalizerSerializer.getDefault(); |         SUT = NormalizerSerializer.getDefault(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testImagePreProcessingScaler() throws Exception { |     public void testImagePreProcessingScaler(Nd4jBackend backend) throws Exception { | ||||||
|  |         File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|         ImagePreProcessingScaler imagePreProcessingScaler = new ImagePreProcessingScaler(0,1); |         ImagePreProcessingScaler imagePreProcessingScaler = new ImagePreProcessingScaler(0,1); | ||||||
|         SUT.write(imagePreProcessingScaler,tmpFile); |         SUT.write(imagePreProcessingScaler,normalizerFile); | ||||||
| 
 | 
 | ||||||
|         ImagePreProcessingScaler restored = SUT.restore(tmpFile); |         ImagePreProcessingScaler restored = SUT.restore(normalizerFile); | ||||||
|         assertEquals(imagePreProcessingScaler,restored); |         assertEquals(imagePreProcessingScaler,restored); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testNormalizerStandardizeNotFitLabels() throws Exception { |     public void testNormalizerStandardizeNotFitLabels(Nd4jBackend backend) throws Exception { | ||||||
|  |         File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|  | 
 | ||||||
|         NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), |         NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), | ||||||
|                 Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)); |                 Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)); | ||||||
| 
 | 
 | ||||||
|         SUT.write(original, tmpFile); |         SUT.write(original, normalizerFile); | ||||||
|         NormalizerStandardize restored = SUT.restore(tmpFile); |         NormalizerStandardize restored = SUT.restore(normalizerFile); | ||||||
| 
 | 
 | ||||||
|         assertEquals(original, restored); |         assertEquals(original, restored); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testNormalizerStandardizeFitLabels() throws Exception { |     public void testNormalizerStandardizeFitLabels(Nd4jBackend backend) throws Exception { | ||||||
|  |         File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|  | 
 | ||||||
|         NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), |         NormalizerStandardize original = new NormalizerStandardize(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), | ||||||
|                 Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1), Nd4j.create(new double[] {4.5, 5.5}).reshape(1, -1), |                 Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1), Nd4j.create(new double[] {4.5, 5.5}).reshape(1, -1), | ||||||
|                 Nd4j.create(new double[] {6.5, 7.5}).reshape(1, -1)); |                 Nd4j.create(new double[] {6.5, 7.5}).reshape(1, -1)); | ||||||
|         original.fitLabel(true); |         original.fitLabel(true); | ||||||
| 
 | 
 | ||||||
|         SUT.write(original, tmpFile); |         SUT.write(original, normalizerFile); | ||||||
|         NormalizerStandardize restored = SUT.restore(tmpFile); |         NormalizerStandardize restored = SUT.restore(normalizerFile); | ||||||
| 
 | 
 | ||||||
|         assertEquals(original, restored); |         assertEquals(original, restored); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testNormalizerMinMaxScalerNotFitLabels() throws Exception { |     public void testNormalizerMinMaxScalerNotFitLabels(Nd4jBackend backend) throws Exception { | ||||||
|  |         File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|  | 
 | ||||||
|         NormalizerMinMaxScaler original = new NormalizerMinMaxScaler(0.1, 0.9); |         NormalizerMinMaxScaler original = new NormalizerMinMaxScaler(0.1, 0.9); | ||||||
|         original.setFeatureStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)); |         original.setFeatureStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), Nd4j.create(new double[] {2.5, 3.5}).reshape(1, -1)); | ||||||
| 
 | 
 | ||||||
|         SUT.write(original, tmpFile); |         SUT.write(original, normalizerFile); | ||||||
|         NormalizerMinMaxScaler restored = SUT.restore(tmpFile); |         NormalizerMinMaxScaler restored = SUT.restore(normalizerFile); | ||||||
| 
 | 
 | ||||||
|         assertEquals(original, restored); |         assertEquals(original, restored); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testNormalizerMinMaxScalerFitLabels() throws Exception { |     public void testNormalizerMinMaxScalerFitLabels(Nd4jBackend backend) throws Exception { | ||||||
|  |         File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|  | 
 | ||||||
|         NormalizerMinMaxScaler original = new NormalizerMinMaxScaler(0.1, 0.9); |         NormalizerMinMaxScaler original = new NormalizerMinMaxScaler(0.1, 0.9); | ||||||
|         original.setFeatureStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})); |         original.setFeatureStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})); | ||||||
|         original.setLabelStats(Nd4j.create(new double[] {4.5, 5.5}), Nd4j.create(new double[] {6.5, 7.5})); |         original.setLabelStats(Nd4j.create(new double[] {4.5, 5.5}), Nd4j.create(new double[] {6.5, 7.5})); | ||||||
|         original.fitLabel(true); |         original.fitLabel(true); | ||||||
| 
 | 
 | ||||||
|         SUT.write(original, tmpFile); |         SUT.write(original, normalizerFile); | ||||||
|         NormalizerMinMaxScaler restored = SUT.restore(tmpFile); |         NormalizerMinMaxScaler restored = SUT.restore(normalizerFile); | ||||||
| 
 | 
 | ||||||
|         assertEquals(original, restored); |         assertEquals(original, restored); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testMultiNormalizerStandardizeNotFitLabels() throws Exception { |     public void testMultiNormalizerStandardizeNotFitLabels(Nd4jBackend backend) throws Exception { | ||||||
|  |         File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|  | 
 | ||||||
|         MultiNormalizerStandardize original = new MultiNormalizerStandardize(); |         MultiNormalizerStandardize original = new MultiNormalizerStandardize(); | ||||||
|         original.setFeatureStats(asList( |         original.setFeatureStats(asList( | ||||||
|                 new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), |                 new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), | ||||||
| @ -145,15 +157,17 @@ public class NormalizerSerializerTest extends BaseNd4jTestWithBackends { | |||||||
|                 new DistributionStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}).reshape(1, -1), |                 new DistributionStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}).reshape(1, -1), | ||||||
|                         Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1)))); |                         Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1)))); | ||||||
| 
 | 
 | ||||||
|         SUT.write(original, tmpFile); |         SUT.write(original, normalizerFile); | ||||||
|         MultiNormalizerStandardize restored = SUT.restore(tmpFile); |         MultiNormalizerStandardize restored = SUT.restore(normalizerFile); | ||||||
| 
 | 
 | ||||||
|         assertEquals(original, restored); |         assertEquals(original, restored); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testMultiNormalizerStandardizeFitLabels() throws Exception { |     public void testMultiNormalizerStandardizeFitLabels(Nd4jBackend backend) throws Exception { | ||||||
|  |         File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|  | 
 | ||||||
|         MultiNormalizerStandardize original = new MultiNormalizerStandardize(); |         MultiNormalizerStandardize original = new MultiNormalizerStandardize(); | ||||||
|         original.setFeatureStats(asList( |         original.setFeatureStats(asList( | ||||||
|                 new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), |                 new DistributionStats(Nd4j.create(new double[] {0.5, 1.5}).reshape(1, -1), | ||||||
| @ -168,30 +182,34 @@ public class NormalizerSerializerTest extends BaseNd4jTestWithBackends { | |||||||
|                         Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1)))); |                         Nd4j.create(new double[] {7.5, 8.5, 9.5}).reshape(1, -1)))); | ||||||
|         original.fitLabel(true); |         original.fitLabel(true); | ||||||
| 
 | 
 | ||||||
|         SUT.write(original, tmpFile); |         SUT.write(original, normalizerFile); | ||||||
|         MultiNormalizerStandardize restored = SUT.restore(tmpFile); |         MultiNormalizerStandardize restored = SUT.restore(normalizerFile); | ||||||
| 
 | 
 | ||||||
|         assertEquals(original, restored); |         assertEquals(original, restored); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testMultiNormalizerMinMaxScalerNotFitLabels() throws Exception { |     public void testMultiNormalizerMinMaxScalerNotFitLabels(Nd4jBackend backend) throws Exception { | ||||||
|  |         File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|  | 
 | ||||||
|         MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9); |         MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9); | ||||||
|         original.setFeatureStats(asList( |         original.setFeatureStats(asList( | ||||||
|                 new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})), |                 new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})), | ||||||
|                 new MinMaxStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}), |                 new MinMaxStats(Nd4j.create(new double[] {4.5, 5.5, 6.5}), | ||||||
|                         Nd4j.create(new double[] {7.5, 8.5, 9.5})))); |                         Nd4j.create(new double[] {7.5, 8.5, 9.5})))); | ||||||
| 
 | 
 | ||||||
|         SUT.write(original, tmpFile); |         SUT.write(original, normalizerFile); | ||||||
|         MultiNormalizerMinMaxScaler restored = SUT.restore(tmpFile); |         MultiNormalizerMinMaxScaler restored = SUT.restore(normalizerFile); | ||||||
| 
 | 
 | ||||||
|         assertEquals(original, restored); |         assertEquals(original, restored); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testMultiNormalizerMinMaxScalerFitLabels() throws Exception { |     public void testMultiNormalizerMinMaxScalerFitLabels(Nd4jBackend backend) throws Exception { | ||||||
|  |         File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|  | 
 | ||||||
|         MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9); |         MultiNormalizerMinMaxScaler original = new MultiNormalizerMinMaxScaler(0.1, 0.9); | ||||||
|         original.setFeatureStats(asList( |         original.setFeatureStats(asList( | ||||||
|                 new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})), |                 new MinMaxStats(Nd4j.create(new double[] {0.5, 1.5}), Nd4j.create(new double[] {2.5, 3.5})), | ||||||
| @ -204,28 +222,32 @@ public class NormalizerSerializerTest extends BaseNd4jTestWithBackends { | |||||||
|                         Nd4j.create(new double[] {7.5, 8.5, 9.5})))); |                         Nd4j.create(new double[] {7.5, 8.5, 9.5})))); | ||||||
|         original.fitLabel(true); |         original.fitLabel(true); | ||||||
| 
 | 
 | ||||||
|         SUT.write(original, tmpFile); |         SUT.write(original, normalizerFile); | ||||||
|         MultiNormalizerMinMaxScaler restored = SUT.restore(tmpFile); |         MultiNormalizerMinMaxScaler restored = SUT.restore(normalizerFile); | ||||||
| 
 | 
 | ||||||
|         assertEquals(original, restored); |         assertEquals(original, restored); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testMultiNormalizerHybridEmpty() throws Exception { |     public void testMultiNormalizerHybridEmpty(Nd4jBackend backend) throws Exception { | ||||||
|  |         File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|  | 
 | ||||||
|         MultiNormalizerHybrid original = new MultiNormalizerHybrid(); |         MultiNormalizerHybrid original = new MultiNormalizerHybrid(); | ||||||
|         original.setInputStats(new HashMap<Integer, NormalizerStats>()); |         original.setInputStats(new HashMap<>()); | ||||||
|         original.setOutputStats(new HashMap<Integer, NormalizerStats>()); |         original.setOutputStats(new HashMap<>()); | ||||||
| 
 | 
 | ||||||
|         SUT.write(original, tmpFile); |         SUT.write(original, normalizerFile); | ||||||
|         MultiNormalizerHybrid restored = SUT.restore(tmpFile); |         MultiNormalizerHybrid restored = SUT.restore(normalizerFile); | ||||||
| 
 | 
 | ||||||
|         assertEquals(original, restored); |         assertEquals(original, restored); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testMultiNormalizerHybridGlobalStats() throws Exception { |     public void testMultiNormalizerHybridGlobalStats(Nd4jBackend backend) throws Exception { | ||||||
|  |         File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|  | 
 | ||||||
|         MultiNormalizerHybrid original = new MultiNormalizerHybrid().minMaxScaleAllInputs().standardizeAllOutputs(); |         MultiNormalizerHybrid original = new MultiNormalizerHybrid().minMaxScaleAllInputs().standardizeAllOutputs(); | ||||||
| 
 | 
 | ||||||
|         Map<Integer, NormalizerStats> inputStats = new HashMap<>(); |         Map<Integer, NormalizerStats> inputStats = new HashMap<>(); | ||||||
| @ -239,15 +261,17 @@ public class NormalizerSerializerTest extends BaseNd4jTestWithBackends { | |||||||
|         original.setInputStats(inputStats); |         original.setInputStats(inputStats); | ||||||
|         original.setOutputStats(outputStats); |         original.setOutputStats(outputStats); | ||||||
| 
 | 
 | ||||||
|         SUT.write(original, tmpFile); |         SUT.write(original, normalizerFile); | ||||||
|         MultiNormalizerHybrid restored = SUT.restore(tmpFile); |         MultiNormalizerHybrid restored = SUT.restore(normalizerFile); | ||||||
| 
 | 
 | ||||||
|         assertEquals(original, restored); |         assertEquals(original, restored); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testMultiNormalizerHybridGlobalAndSpecificStats() throws Exception { |     public void testMultiNormalizerHybridGlobalAndSpecificStats(Nd4jBackend backend) throws Exception { | ||||||
|  |         File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|  | 
 | ||||||
|         MultiNormalizerHybrid original = new MultiNormalizerHybrid().standardizeAllInputs().minMaxScaleInput(0, -5, 5) |         MultiNormalizerHybrid original = new MultiNormalizerHybrid().standardizeAllInputs().minMaxScaleInput(0, -5, 5) | ||||||
|                 .minMaxScaleAllOutputs(-10, 10).standardizeOutput(1); |                 .minMaxScaleAllOutputs(-10, 10).standardizeOutput(1); | ||||||
| 
 | 
 | ||||||
| @ -262,29 +286,35 @@ public class NormalizerSerializerTest extends BaseNd4jTestWithBackends { | |||||||
|         original.setInputStats(inputStats); |         original.setInputStats(inputStats); | ||||||
|         original.setOutputStats(outputStats); |         original.setOutputStats(outputStats); | ||||||
| 
 | 
 | ||||||
|         SUT.write(original, tmpFile); |         SUT.write(original, normalizerFile); | ||||||
|         MultiNormalizerHybrid restored = SUT.restore(tmpFile); |         MultiNormalizerHybrid restored = SUT.restore(normalizerFile); | ||||||
| 
 | 
 | ||||||
|         assertEquals(original, restored); |         assertEquals(original, restored); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Test() | 
 | ||||||
|     public void testCustomNormalizerWithoutRegisteredStrategy() throws Exception { |     @ParameterizedTest | ||||||
|  |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     public void testCustomNormalizerWithoutRegisteredStrategy(Nd4jBackend backend) throws Exception { | ||||||
|         assertThrows(RuntimeException.class, () -> { |         assertThrows(RuntimeException.class, () -> { | ||||||
|             SUT.write(new MyNormalizer(123), tmpFile); |             File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|  | 
 | ||||||
|  |             SUT.write(new MyNormalizer(123), normalizerFile); | ||||||
| 
 | 
 | ||||||
|         }); |         }); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testCustomNormalizer() throws Exception { |     public void testCustomNormalizer(Nd4jBackend backend) throws Exception { | ||||||
|  |         File normalizerFile = Files.createTempFile(tmpFile.toPath(),"pre-process-" + UUID.randomUUID().toString(),"bin").toFile(); | ||||||
|  | 
 | ||||||
|         MyNormalizer original = new MyNormalizer(42); |         MyNormalizer original = new MyNormalizer(42); | ||||||
| 
 | 
 | ||||||
|         SUT.addStrategy(new MyNormalizerSerializerStrategy()); |         SUT.addStrategy(new MyNormalizerSerializerStrategy()); | ||||||
| 
 | 
 | ||||||
|         SUT.write(original, tmpFile); |         SUT.write(original, normalizerFile); | ||||||
|         MyNormalizer restored = SUT.restore(tmpFile); |         MyNormalizer restored = SUT.restore(normalizerFile); | ||||||
| 
 | 
 | ||||||
|         assertEquals(original, restored); |         assertEquals(original, restored); | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -42,6 +42,7 @@ import org.nd4j.common.util.ArrayUtil; | |||||||
| import org.nd4j.nativeblas.NativeOpsHolder; | import org.nd4j.nativeblas.NativeOpsHolder; | ||||||
| 
 | 
 | ||||||
| import java.io.File; | import java.io.File; | ||||||
|  | import java.nio.Buffer; | ||||||
| import java.nio.ByteBuffer; | import java.nio.ByteBuffer; | ||||||
| import java.util.ArrayList; | import java.util.ArrayList; | ||||||
| import java.util.Arrays; | import java.util.Arrays; | ||||||
| @ -248,7 +249,8 @@ public class Nd4jTest extends BaseNd4jTestWithBackends { | |||||||
|         byte[] dataTwo = new byte[floatBuffer.capacity()]; |         byte[] dataTwo = new byte[floatBuffer.capacity()]; | ||||||
|         floatBuffer.get(dataTwo); |         floatBuffer.get(dataTwo); | ||||||
|         assertArrayEquals(originalData,dataTwo); |         assertArrayEquals(originalData,dataTwo); | ||||||
|         floatBuffer.position(0); |         Buffer buffer = (Buffer) floatBuffer; | ||||||
|  |         buffer.position(0); | ||||||
| 
 | 
 | ||||||
|         DataBuffer dataBuffer = Nd4j.createBuffer(new FloatPointer(floatBuffer.asFloatBuffer()),linspace.length(), DataType.FLOAT); |         DataBuffer dataBuffer = Nd4j.createBuffer(new FloatPointer(floatBuffer.asFloatBuffer()),linspace.length(), DataType.FLOAT); | ||||||
|         assertArrayEquals(new float[]{1,2,3,4}, dataBuffer.asFloat(), 1e-5f); |         assertArrayEquals(new float[]{1,2,3,4}, dataBuffer.asFloat(), 1e-5f); | ||||||
|  | |||||||
| @ -23,6 +23,8 @@ package org.nd4j.linalg.ops; | |||||||
| import lombok.val; | import lombok.val; | ||||||
| import org.junit.jupiter.api.Disabled; | import org.junit.jupiter.api.Disabled; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| 
 | 
 | ||||||
| @ -116,8 +118,10 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     @Test |     @ParameterizedTest | ||||||
|     public void testDistance() throws Exception { |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Execution(ExecutionMode.SAME_THREAD) | ||||||
|  |     public void testDistance(Nd4jBackend backend) throws Exception { | ||||||
|         INDArray matrix = Nd4j.rand(new int[] {400,10}); |         INDArray matrix = Nd4j.rand(new int[] {400,10}); | ||||||
|         INDArray rowVector = matrix.getRow(70); |         INDArray rowVector = matrix.getRow(70); | ||||||
|         INDArray resultArr = Nd4j.zeros(400,1); |         INDArray resultArr = Nd4j.zeros(400,1); | ||||||
| @ -127,8 +131,6 @@ public class OpExecutionerTests extends BaseNd4jTestWithBackends { | |||||||
|             System.out.println("Ran!"); |             System.out.println("Ran!"); | ||||||
|         }); |         }); | ||||||
| 
 | 
 | ||||||
|         Thread.sleep(600000); |  | ||||||
| 
 |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|  | |||||||
| @ -82,11 +82,9 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point; | |||||||
| @Slf4j | @Slf4j | ||||||
| @NativeTag | @NativeTag | ||||||
| public class OpExecutionerTestsC extends BaseNd4jTestWithBackends { | public class OpExecutionerTestsC extends BaseNd4jTestWithBackends { | ||||||
|     DataType initialType = Nd4j.dataType(); |  | ||||||
| 
 | 
 | ||||||
|     @AfterEach |     @AfterEach | ||||||
|     public void after() { |     public void after() { | ||||||
|         Nd4j.setDataType(this.initialType); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -23,6 +23,8 @@ package org.nd4j.linalg.profiling; | |||||||
| import org.junit.jupiter.api.AfterEach; | import org.junit.jupiter.api.AfterEach; | ||||||
| import org.junit.jupiter.api.BeforeEach; | import org.junit.jupiter.api.BeforeEach; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| 
 | 
 | ||||||
| @ -39,6 +41,7 @@ import org.nd4j.linalg.profiler.ProfilerConfig; | |||||||
| import static org.junit.jupiter.api.Assertions.assertThrows; | import static org.junit.jupiter.api.Assertions.assertThrows; | ||||||
| 
 | 
 | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Execution(ExecutionMode.SAME_THREAD) | ||||||
| public class InfNanTests extends BaseNd4jTestWithBackends { | public class InfNanTests extends BaseNd4jTestWithBackends { | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -27,6 +27,9 @@ import org.junit.jupiter.api.AfterEach; | |||||||
| import org.junit.jupiter.api.BeforeEach; | import org.junit.jupiter.api.BeforeEach; | ||||||
| import org.junit.jupiter.api.Disabled; | import org.junit.jupiter.api.Disabled; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
|  | import org.junit.jupiter.api.parallel.Isolated; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| import org.nd4j.common.tests.tags.NativeTag; | import org.nd4j.common.tests.tags.NativeTag; | ||||||
| @ -52,6 +55,8 @@ import static org.junit.jupiter.api.Assertions.*; | |||||||
| 
 | 
 | ||||||
| @Slf4j | @Slf4j | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Isolated | ||||||
|  | @Execution(ExecutionMode.SAME_THREAD) | ||||||
| public class OperationProfilerTests extends BaseNd4jTestWithBackends { | public class OperationProfilerTests extends BaseNd4jTestWithBackends { | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -229,9 +234,10 @@ public class OperationProfilerTests extends BaseNd4jTestWithBackends { | |||||||
|         assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.TAD_NON_EWS_ACCESS)); |         assertTrue(ArrayUtils.contains(causes, OpProfiler.PenaltyCause.TAD_NON_EWS_ACCESS)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Test |     @ParameterizedTest | ||||||
|  |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testBadTad4(Nd4jBackend backend) { |     public void testBadTad4(Nd4jBackend backend) { | ||||||
|         INDArray x = Nd4j.create(2, 4, 5, 6); |         INDArray x = Nd4j.create(DataType.DOUBLE,2, 4, 5, 6); | ||||||
| 
 | 
 | ||||||
|         Pair<DataBuffer, DataBuffer> pair = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, 3); |         Pair<DataBuffer, DataBuffer> pair = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, 3); | ||||||
| 
 | 
 | ||||||
| @ -473,7 +479,7 @@ public class OperationProfilerTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testNanPanic(){ |     public void testNanPanic(Nd4jBackend backend) { | ||||||
|         try { |         try { | ||||||
|             DynamicCustomOp op = DynamicCustomOp.builder("add") |             DynamicCustomOp op = DynamicCustomOp.builder("add") | ||||||
|                     .addInputs(Nd4j.valueArrayOf(10, Double.NaN).castTo(DataType.DOUBLE), Nd4j.scalar(0.0)) |                     .addInputs(Nd4j.valueArrayOf(10, Double.NaN).castTo(DataType.DOUBLE), Nd4j.scalar(0.0)) | ||||||
|  | |||||||
| @ -441,6 +441,7 @@ public class RandomTests extends BaseNd4jTestWithBackends { | |||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     @Tag(TagNames.LONG_TEST) |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|     public void testStepOver1(Nd4jBackend backend) { |     public void testStepOver1(Nd4jBackend backend) { | ||||||
|         Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); |         Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); | ||||||
| 
 | 
 | ||||||
| @ -466,6 +467,8 @@ public class RandomTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|     public void testSum_119(Nd4jBackend backend) { |     public void testSum_119(Nd4jBackend backend) { | ||||||
|         INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000); |         INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000); | ||||||
|         val sum = z2.sumNumber().doubleValue(); |         val sum = z2.sumNumber().doubleValue(); | ||||||
| @ -474,6 +477,8 @@ public class RandomTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|     public void testLegacyDistribution1(Nd4jBackend backend) { |     public void testLegacyDistribution1(Nd4jBackend backend) { | ||||||
|         NormalDistribution distribution = new NormalDistribution(new DefaultRandom(), 0.0, 1.0); |         NormalDistribution distribution = new NormalDistribution(new DefaultRandom(), 0.0, 1.0); | ||||||
|         INDArray z1 = distribution.sample(new int[] {1, 1000000}); |         INDArray z1 = distribution.sample(new int[] {1, 1000000}); | ||||||
| @ -923,9 +928,10 @@ public class RandomTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|     public void testDeallocation1() throws Exception { |     public void testDeallocation1() throws Exception { | ||||||
| 
 |         for(int i = 0; i < 1000; i++) { | ||||||
|         while (true) { |  | ||||||
|             Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); |             Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119); | ||||||
|             random1.nextInt(); |             random1.nextInt(); | ||||||
| 
 | 
 | ||||||
| @ -934,6 +940,7 @@ public class RandomTests extends BaseNd4jTestWithBackends { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void someTest(Nd4jBackend backend) { |     public void someTest(Nd4jBackend backend) { | ||||||
|  | |||||||
| @ -29,6 +29,8 @@ import lombok.extern.slf4j.Slf4j; | |||||||
| import org.junit.jupiter.api.Disabled; | import org.junit.jupiter.api.Disabled; | ||||||
| import org.junit.jupiter.api.Tag; | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| import org.nd4j.common.base.Preconditions; | import org.nd4j.common.base.Preconditions; | ||||||
| @ -70,6 +72,7 @@ import java.util.Map; | |||||||
| @Slf4j | @Slf4j | ||||||
| @NativeTag | @NativeTag | ||||||
| @Tag(TagNames.RNG) | @Tag(TagNames.RNG) | ||||||
|  | @Execution(ExecutionMode.SAME_THREAD) | ||||||
| public class RngValidationTests extends BaseNd4jTestWithBackends { | public class RngValidationTests extends BaseNd4jTestWithBackends { | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -129,6 +132,8 @@ public class RngValidationTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Disabled | ||||||
|  |     @Tag(TagNames.NEEDS_VERIFY) | ||||||
|     public void validateRngDistributions(Nd4jBackend backend){ |     public void validateRngDistributions(Nd4jBackend backend){ | ||||||
|         List<TestCase> testCases = new ArrayList<>(); |         List<TestCase> testCases = new ArrayList<>(); | ||||||
|         for(DataType type : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { |         for(DataType type : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { | ||||||
| @ -264,7 +269,7 @@ public class RngValidationTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|         int count = 1; |         int count = 1; | ||||||
|         for(TestCase tc : testCases){ |         for(TestCase tc : testCases) { | ||||||
|             log.info("Starting test case: {} of {}", count, testCases.size()); |             log.info("Starting test case: {} of {}", count, testCases.size()); | ||||||
|             log.info("{}", tc); |             log.info("{}", tc); | ||||||
| 
 | 
 | ||||||
| @ -314,7 +319,7 @@ public class RngValidationTests extends BaseNd4jTestWithBackends { | |||||||
|             assertEquals(z, z2); |             assertEquals(z, z2); | ||||||
| 
 | 
 | ||||||
|             //Check mean, stdev |             //Check mean, stdev | ||||||
|             if(tc.getExpectedMean() != null){ |             if(tc.getExpectedMean() != null) { | ||||||
|                 double mean = z.meanNumber().doubleValue(); |                 double mean = z.meanNumber().doubleValue(); | ||||||
|                 double re = relError(tc.getExpectedMean(), mean); |                 double re = relError(tc.getExpectedMean(), mean); | ||||||
|                 double ae = Math.abs(tc.getExpectedMean() - mean); |                 double ae = Math.abs(tc.getExpectedMean() - mean); | ||||||
|  | |||||||
| @ -44,9 +44,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; | |||||||
| @Slf4j | @Slf4j | ||||||
| @Tag(TagNames.JACKSON_SERDE) | @Tag(TagNames.JACKSON_SERDE) | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
|  | @Tag(TagNames.LONG_TEST) | ||||||
| public class LargeSerDeTests extends BaseNd4jTestWithBackends { | public class LargeSerDeTests extends BaseNd4jTestWithBackends { | ||||||
| 
 | 
 | ||||||
|       @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testLargeArraySerDe_1(Nd4jBackend backend) throws Exception { |     public void testLargeArraySerDe_1(Nd4jBackend backend) throws Exception { | ||||||
|         val arrayA = Nd4j.rand(new long[] {1, 135079944}); |         val arrayA = Nd4j.rand(new long[] {1, 135079944}); | ||||||
|  | |||||||
| @ -42,6 +42,7 @@ import org.nd4j.common.io.ClassPathResource; | |||||||
| import java.io.File; | import java.io.File; | ||||||
| import java.nio.file.Path; | import java.nio.file.Path; | ||||||
| import java.util.Map; | import java.util.Map; | ||||||
|  | import java.util.UUID; | ||||||
| 
 | 
 | ||||||
| import static org.junit.jupiter.api.Assertions.*; | import static org.junit.jupiter.api.Assertions.*; | ||||||
| 
 | 
 | ||||||
| @ -56,7 +57,8 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends { | |||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testToNpyFormat(Nd4jBackend backend) throws Exception { |     public void testToNpyFormat(Nd4jBackend backend) throws Exception { | ||||||
| 
 | 
 | ||||||
|         val dir = testDir.toFile(); |         val dir = testDir.resolve("new-dir-" + UUID.randomUUID().toString()).toFile(); | ||||||
|  |         assertTrue(dir.mkdirs()); | ||||||
|         new ClassPathResource("numpy_arrays/").copyDirectory(dir); |         new ClassPathResource("numpy_arrays/").copyDirectory(dir); | ||||||
| 
 | 
 | ||||||
|         File[] files = dir.listFiles(); |         File[] files = dir.listFiles(); | ||||||
| @ -107,14 +109,15 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends { | |||||||
|     public void testToNpyFormatScalars(Nd4jBackend backend) throws Exception { |     public void testToNpyFormatScalars(Nd4jBackend backend) throws Exception { | ||||||
| //        File dir = new File("C:\\DL4J\\Git\\dl4j-test-resources\\src\\main\\resources\\numpy_arrays\\scalar"); | //        File dir = new File("C:\\DL4J\\Git\\dl4j-test-resources\\src\\main\\resources\\numpy_arrays\\scalar"); | ||||||
| 
 | 
 | ||||||
|         val dir = testDir.toFile(); |         val dir = testDir.resolve("new-path0" + UUID.randomUUID().toString()).toFile(); | ||||||
|  |         dir.mkdirs(); | ||||||
|         new ClassPathResource("numpy_arrays/scalar/").copyDirectory(dir); |         new ClassPathResource("numpy_arrays/scalar/").copyDirectory(dir); | ||||||
| 
 | 
 | ||||||
|         File[] files = dir.listFiles(); |         File[] files = dir.listFiles(); | ||||||
|         int cnt = 0; |         int cnt = 0; | ||||||
| 
 | 
 | ||||||
|         for(File f : files){ |         for(File f : files){ | ||||||
|             if(!f.getPath().endsWith(".npy")){ |             if(!f.getPath().endsWith(".npy")) { | ||||||
|                 log.warn("Skipping: {}", f); |                 log.warn("Skipping: {}", f); | ||||||
|                 continue; |                 continue; | ||||||
|             } |             } | ||||||
| @ -161,7 +164,8 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends { | |||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testNpzReading(Nd4jBackend backend) throws Exception { |     public void testNpzReading(Nd4jBackend backend) throws Exception { | ||||||
| 
 | 
 | ||||||
|         val dir = testDir.toFile(); |         val dir = testDir.resolve("new-folder-npz").toFile(); | ||||||
|  |         dir.mkdirs(); | ||||||
|         new ClassPathResource("numpy_arrays/npz/").copyDirectory(dir); |         new ClassPathResource("numpy_arrays/npz/").copyDirectory(dir); | ||||||
| 
 | 
 | ||||||
|         File[] files = dir.listFiles(); |         File[] files = dir.listFiles(); | ||||||
| @ -222,7 +226,8 @@ public class NumpyFormatTests extends BaseNd4jTestWithBackends { | |||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testNpy(Nd4jBackend backend) throws Exception { |     public void testNpy(Nd4jBackend backend) throws Exception { | ||||||
|         for(boolean empty : new boolean[]{false, true}) { |         for(boolean empty : new boolean[]{false, true}) { | ||||||
|             val dir = testDir.toFile(); |             val dir = testDir.resolve("new-dir-1-" + UUID.randomUUID().toString()).toFile(); | ||||||
|  |             assertTrue(dir.mkdirs()); | ||||||
|             if(!empty) { |             if(!empty) { | ||||||
|                 new ClassPathResource("numpy_arrays/npy/3,4/").copyDirectory(dir); |                 new ClassPathResource("numpy_arrays/npy/3,4/").copyDirectory(dir); | ||||||
|             } else { |             } else { | ||||||
|  | |||||||
| @ -403,13 +403,13 @@ public class ShapeTestsC extends BaseNd4jTestWithBackends { | |||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testRavel(Nd4jBackend backend) { |     public void testRavel(Nd4jBackend backend) { | ||||||
|         INDArray linspace = Nd4j.linspace(1, 4, 4).reshape(2, 2); |         INDArray linspace = Nd4j.linspace(1, 4, 4,DataType.DOUBLE).reshape(2, 2); | ||||||
|         INDArray asseriton = Nd4j.linspace(1, 4, 4); |         INDArray asseriton = Nd4j.linspace(1, 4, 4,DataType.DOUBLE); | ||||||
|         INDArray raveled = linspace.ravel(); |         INDArray raveled = linspace.ravel(); | ||||||
|         assertEquals(asseriton, raveled); |         assertEquals(asseriton, raveled); | ||||||
| 
 | 
 | ||||||
|         INDArray tensorLinSpace = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2); |         INDArray tensorLinSpace = Nd4j.linspace(1, 16, 16,DataType.DOUBLE).reshape(2, 2, 2, 2); | ||||||
|         INDArray linspaced = Nd4j.linspace(1, 16, 16); |         INDArray linspaced = Nd4j.linspace(1, 16, 16,DataType.DOUBLE); | ||||||
|         INDArray tensorLinspaceRaveled = tensorLinSpace.ravel(); |         INDArray tensorLinspaceRaveled = tensorLinSpace.ravel(); | ||||||
|         assertEquals(linspaced, tensorLinspaceRaveled); |         assertEquals(linspaced, tensorLinspaceRaveled); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -227,7 +227,7 @@ public class ConcatTestsC extends BaseNd4jTestWithBackends { | |||||||
|         assertEquals(exp, concat2); |         assertEquals(exp, concat2); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|      @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testConcatVector(Nd4jBackend backend) { |     public void testConcatVector(Nd4jBackend backend) { | ||||||
|         assertThrows(ND4JIllegalStateException.class,() -> { |         assertThrows(ND4JIllegalStateException.class,() -> { | ||||||
| @ -236,7 +236,6 @@ public class ConcatTestsC extends BaseNd4jTestWithBackends { | |||||||
|         }); |         }); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Test |  | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testConcat3dv2(Nd4jBackend backend) { |     public void testConcat3dv2(Nd4jBackend backend) { | ||||||
|  | |||||||
| @ -21,9 +21,9 @@ | |||||||
| package org.nd4j.linalg.specials; | package org.nd4j.linalg.specials; | ||||||
| 
 | 
 | ||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.junit.jupiter.api.Disabled; | import org.junit.jupiter.api.*; | ||||||
| import org.junit.jupiter.api.Tag; | import org.junit.jupiter.api.parallel.Execution; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
| import org.junit.jupiter.api.parallel.Isolated; | import org.junit.jupiter.api.parallel.Isolated; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| @ -49,19 +49,30 @@ import static org.junit.jupiter.api.Assertions.assertNotEquals; | |||||||
| @Slf4j | @Slf4j | ||||||
| @NativeTag | @NativeTag | ||||||
| @Isolated | @Isolated | ||||||
|  | @Execution(ExecutionMode.SAME_THREAD) | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
| public class LongTests extends BaseNd4jTestWithBackends { | public class LongTests extends BaseNd4jTestWithBackends { | ||||||
| 
 | 
 | ||||||
|     DataType initialType = Nd4j.dataType(); |     DataType initialType = Nd4j.dataType(); | ||||||
|  |     @BeforeEach | ||||||
|  |     public void beforeEach() { | ||||||
|  |         System.gc(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @AfterEach | ||||||
|  |     public void afterEach() { | ||||||
|  |         System.gc(); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     @Tag(TagNames.LONG_TEST) |     @Tag(TagNames.LONG_TEST) | ||||||
|     public void testSomething1(Nd4jBackend backend) { |     public void testSomething1(Nd4jBackend backend) { | ||||||
|         // we create 2D array, total nr. of elements is 2.4B elements, > MAX_INT |         // we create 2D array, total nr. of elements is 2.4B elements, > MAX_INT | ||||||
|         INDArray huge = Nd4j.create(8000000, 300); |         INDArray huge = Nd4j.create(DataType.INT8,8000000, 300); | ||||||
| 
 | 
 | ||||||
|         // we apply element-wise scalar ops, just to make sure stuff still works |         // we apply element-wise scalar ops, just to make sure stuff still works | ||||||
|         huge.subi(0.5).divi(2); |         huge.subi(1).divi(2); | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|         // now we're checking different rows, they should NOT equal |         // now we're checking different rows, they should NOT equal | ||||||
| @ -86,10 +97,10 @@ public class LongTests extends BaseNd4jTestWithBackends { | |||||||
|     @Tag(TagNames.LONG_TEST) |     @Tag(TagNames.LONG_TEST) | ||||||
|     public void testSomething2(Nd4jBackend backend) { |     public void testSomething2(Nd4jBackend backend) { | ||||||
|         // we create 2D array, total nr. of elements is 2.4B elements, > MAX_INT |         // we create 2D array, total nr. of elements is 2.4B elements, > MAX_INT | ||||||
|         INDArray huge = Nd4j.create(100, 10); |         INDArray huge = Nd4j.create(DataType.INT8,100, 10); | ||||||
| 
 | 
 | ||||||
|         // we apply element-wise scalar ops, just to make sure stuff still works |         // we apply element-wise scalar ops, just to make sure stuff still works | ||||||
|         huge.subi(0.5).divi(2); |         huge.subi(1).divi(2); | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|         // now we're checking different rows, they should NOT equal |         // now we're checking different rows, they should NOT equal | ||||||
| @ -113,7 +124,7 @@ public class LongTests extends BaseNd4jTestWithBackends { | |||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     @Tag(TagNames.LONG_TEST) |     @Tag(TagNames.LONG_TEST) | ||||||
|     public void testLongTadOffsets1(Nd4jBackend backend) { |     public void testLongTadOffsets1(Nd4jBackend backend) { | ||||||
|         INDArray huge = Nd4j.create(230000000, 10); |         INDArray huge = Nd4j.create(DataType.INT8,230000000, 10); | ||||||
| 
 | 
 | ||||||
|         Pair<DataBuffer, DataBuffer> tad = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(huge, 1); |         Pair<DataBuffer, DataBuffer> tad = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(huge, 1); | ||||||
| 
 | 
 | ||||||
| @ -125,10 +136,10 @@ public class LongTests extends BaseNd4jTestWithBackends { | |||||||
|     @Tag(TagNames.LONG_TEST) |     @Tag(TagNames.LONG_TEST) | ||||||
|     public void testLongTadOp1(Nd4jBackend backend) { |     public void testLongTadOp1(Nd4jBackend backend) { | ||||||
| 
 | 
 | ||||||
|         double exp = Transforms.manhattanDistance(Nd4j.create(1000).assign(1.0), Nd4j.create(1000).assign(2.0)); |         double exp = Transforms.manhattanDistance(Nd4j.create(DataType.INT16,1000).assign(1.0), Nd4j.create(DataType.INT16,1000).assign(2.0)); | ||||||
| 
 | 
 | ||||||
|         INDArray hugeX = Nd4j.create(2200000, 1000).assign(1.0); |         INDArray hugeX = Nd4j.create(DataType.INT16,2200000, 1000).assign(1.0); | ||||||
|         INDArray hugeY = Nd4j.create(1, 1000).assign(2.0); |         INDArray hugeY = Nd4j.create(DataType.INT16,1, 1000).assign(2.0); | ||||||
| 
 | 
 | ||||||
|         for (int x = 0; x < hugeX.rows(); x++) { |         for (int x = 0; x < hugeX.rows(); x++) { | ||||||
|             assertEquals(1000, hugeX.getRow(x).sumNumber().intValue(),"Failed at row " + x); |             assertEquals(1000, hugeX.getRow(x).sumNumber().intValue(),"Failed at row " + x); | ||||||
| @ -144,9 +155,8 @@ public class LongTests extends BaseNd4jTestWithBackends { | |||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     @Tag(TagNames.LONG_TEST) |     @Tag(TagNames.LONG_TEST) | ||||||
|     public void testLongTadOp2(Nd4jBackend backend) { |     public void testLongTadOp2(Nd4jBackend backend) { | ||||||
| 
 |         INDArray hugeX = Nd4j.create(DataType.INT16,2300000, 1000).assign(1.0); | ||||||
|         INDArray hugeX = Nd4j.create(2300000, 1000).assign(1.0); |         hugeX.addiRowVector(Nd4j.create(DataType.INT16,1000).assign(2.0)); | ||||||
|         hugeX.addiRowVector(Nd4j.create(1000).assign(2.0)); |  | ||||||
| 
 | 
 | ||||||
|         for (int x = 0; x < hugeX.rows(); x++) { |         for (int x = 0; x < hugeX.rows(); x++) { | ||||||
|             assertEquals( hugeX.getRow(x).sumNumber().intValue(),3000,"Failed at row " + x); |             assertEquals( hugeX.getRow(x).sumNumber().intValue(),3000,"Failed at row " + x); | ||||||
| @ -158,8 +168,8 @@ public class LongTests extends BaseNd4jTestWithBackends { | |||||||
|     @Tag(TagNames.LONG_TEST) |     @Tag(TagNames.LONG_TEST) | ||||||
|     public void testLongTadOp2_micro(Nd4jBackend backend) { |     public void testLongTadOp2_micro(Nd4jBackend backend) { | ||||||
| 
 | 
 | ||||||
|         INDArray hugeX = Nd4j.create(230, 1000).assign(1.0); |         INDArray hugeX = Nd4j.create(DataType.INT16,230, 1000).assign(1.0); | ||||||
|         hugeX.addiRowVector(Nd4j.create(1000).assign(2.0)); |         hugeX.addiRowVector(Nd4j.create(DataType.INT16,1000).assign(2.0)); | ||||||
| 
 | 
 | ||||||
|         for (int x = 0; x < hugeX.rows(); x++) { |         for (int x = 0; x < hugeX.rows(); x++) { | ||||||
|             assertEquals( 3000, hugeX.getRow(x).sumNumber().intValue(),"Failed at row " + x); |             assertEquals( 3000, hugeX.getRow(x).sumNumber().intValue(),"Failed at row " + x); | ||||||
| @ -171,7 +181,7 @@ public class LongTests extends BaseNd4jTestWithBackends { | |||||||
|     @Tag(TagNames.LONG_TEST) |     @Tag(TagNames.LONG_TEST) | ||||||
|     public void testLongTadOp3(Nd4jBackend backend) { |     public void testLongTadOp3(Nd4jBackend backend) { | ||||||
| 
 | 
 | ||||||
|         INDArray hugeX = Nd4j.create(2300000, 1000).assign(1.0); |         INDArray hugeX = Nd4j.create(DataType.INT16,2300000, 1000).assign(1.0); | ||||||
|         INDArray mean = hugeX.mean(1); |         INDArray mean = hugeX.mean(1); | ||||||
| 
 | 
 | ||||||
|         for (int x = 0; x < hugeX.rows(); x++) { |         for (int x = 0; x < hugeX.rows(); x++) { | ||||||
| @ -184,7 +194,7 @@ public class LongTests extends BaseNd4jTestWithBackends { | |||||||
|     @Tag(TagNames.LONG_TEST) |     @Tag(TagNames.LONG_TEST) | ||||||
|     public void testLongTadOp4(Nd4jBackend backend) { |     public void testLongTadOp4(Nd4jBackend backend) { | ||||||
| 
 | 
 | ||||||
|         INDArray hugeX = Nd4j.create(2300000, 1000).assign(1.0); |         INDArray hugeX = Nd4j.create(DataType.INT8,2300000, 1000).assign(1.0); | ||||||
|         INDArray mean = hugeX.argMax(1); |         INDArray mean = hugeX.argMax(1); | ||||||
| 
 | 
 | ||||||
|         for (int x = 0; x < hugeX.rows(); x++) { |         for (int x = 0; x < hugeX.rows(); x++) { | ||||||
| @ -199,7 +209,7 @@ public class LongTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|         List<INDArray> list = new ArrayList<>(); |         List<INDArray> list = new ArrayList<>(); | ||||||
|         for (int i = 0; i < 2300000; i++) { |         for (int i = 0; i < 2300000; i++) { | ||||||
|             list.add(Nd4j.create(1000).assign(2.0)); |             list.add(Nd4j.create(DataType.INT8,1000).assign(2.0)); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         INDArray hugeX = Nd4j.vstack(list); |         INDArray hugeX = Nd4j.vstack(list); | ||||||
|  | |||||||
| @ -23,6 +23,8 @@ package org.nd4j.linalg.workspace; | |||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import lombok.val; | import lombok.val; | ||||||
| import org.junit.jupiter.api.*; | import org.junit.jupiter.api.*; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| 
 | 
 | ||||||
| @ -54,6 +56,7 @@ import static org.nd4j.linalg.api.buffer.DataType.DOUBLE; | |||||||
| @Slf4j | @Slf4j | ||||||
| @Tag(TagNames.WORKSPACES) | @Tag(TagNames.WORKSPACES) | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Execution(ExecutionMode.SAME_THREAD) | ||||||
| public class BasicWorkspaceTests extends BaseNd4jTestWithBackends { | public class BasicWorkspaceTests extends BaseNd4jTestWithBackends { | ||||||
|     DataType initialType = Nd4j.dataType(); |     DataType initialType = Nd4j.dataType(); | ||||||
| 
 | 
 | ||||||
| @ -959,6 +962,7 @@ public class BasicWorkspaceTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Execution(ExecutionMode.SAME_THREAD) | ||||||
|     public void testMmap1(Nd4jBackend backend) { |     public void testMmap1(Nd4jBackend backend) { | ||||||
|         // we don't support MMAP on cuda yet |         // we don't support MMAP on cuda yet | ||||||
|         if (Nd4j.getExecutioner().getClass().getName().toLowerCase().contains("cuda")) |         if (Nd4j.getExecutioner().getClass().getName().toLowerCase().contains("cuda")) | ||||||
| @ -989,12 +993,13 @@ public class BasicWorkspaceTests extends BaseNd4jTestWithBackends { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     @Test |  | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Execution(ExecutionMode.SAME_THREAD) | ||||||
|  |     @Disabled("Still failing even with single thread execution") | ||||||
|     public void testMmap2(Nd4jBackend backend) throws Exception { |     public void testMmap2(Nd4jBackend backend) throws Exception { | ||||||
|         // we don't support MMAP on cuda yet |         // we don't support MMAP on cuda yet | ||||||
|         if (Nd4j.getExecutioner().getClass().getName().toLowerCase().contains("cuda")) |         if (!backend.getEnvironment().isCPU()) | ||||||
|             return; |             return; | ||||||
| 
 | 
 | ||||||
|         File tmp = File.createTempFile("tmp", "fdsfdf"); |         File tmp = File.createTempFile("tmp", "fdsfdf"); | ||||||
|  | |||||||
| @ -20,6 +20,7 @@ | |||||||
| 
 | 
 | ||||||
| package org.nd4j.linalg.workspace; | package org.nd4j.linalg.workspace; | ||||||
| 
 | 
 | ||||||
|  | import lombok.SneakyThrows; | ||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import lombok.val; | import lombok.val; | ||||||
| import org.junit.jupiter.api.Disabled; | import org.junit.jupiter.api.Disabled; | ||||||
| @ -65,8 +66,11 @@ public class CyclicWorkspaceTests extends BaseNd4jTestWithBackends { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     @SneakyThrows | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Tag(TagNames.LONG_TEST) | ||||||
|  |     @Tag(TagNames.LARGE_RESOURCES) | ||||||
|     public void testGc(Nd4jBackend backend) { |     public void testGc(Nd4jBackend backend) { | ||||||
|         val indArray = Nd4j.create(4, 4); |         val indArray = Nd4j.create(4, 4); | ||||||
|         indArray.putRow(0, Nd4j.create(new float[]{0, 2, -2, 0})); |         indArray.putRow(0, Nd4j.create(new float[]{0, 2, -2, 0})); | ||||||
| @ -76,7 +80,7 @@ public class CyclicWorkspaceTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|         for (int i = 0; i < 100000000; i++) { |         for (int i = 0; i < 100000000; i++) { | ||||||
|             indArray.getRow(i % 3); |             indArray.getRow(i % 3); | ||||||
|             //Thread.sleep(1); |             Thread.sleep(1); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -26,6 +26,8 @@ import org.junit.jupiter.api.AfterEach; | |||||||
| import org.junit.jupiter.api.BeforeEach; | import org.junit.jupiter.api.BeforeEach; | ||||||
| import org.junit.jupiter.api.Tag; | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| 
 | 
 | ||||||
| @ -48,8 +50,8 @@ import static org.junit.jupiter.api.Assertions.*; | |||||||
| @Slf4j | @Slf4j | ||||||
| @Tag(TagNames.WORKSPACES) | @Tag(TagNames.WORKSPACES) | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Execution(ExecutionMode.SAME_THREAD) | ||||||
| public class DebugModeTests extends BaseNd4jTestWithBackends { | public class DebugModeTests extends BaseNd4jTestWithBackends { | ||||||
|     DataType initialType = Nd4j.dataType(); |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -26,6 +26,8 @@ import org.junit.jupiter.api.AfterEach; | |||||||
| import org.junit.jupiter.api.Disabled; | import org.junit.jupiter.api.Disabled; | ||||||
| import org.junit.jupiter.api.Tag; | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| 
 | 
 | ||||||
| @ -62,12 +64,11 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends { | |||||||
|     public void shutUp() { |     public void shutUp() { | ||||||
|         Nd4j.getMemoryManager().setCurrentWorkspace(null); |         Nd4j.getMemoryManager().setCurrentWorkspace(null); | ||||||
|         Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); |         Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); | ||||||
|         Nd4j.setDataType(this.initialType); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     @Disabled |     @Execution(ExecutionMode.SAME_THREAD) | ||||||
|     public void testVariableTimeSeries1(Nd4jBackend backend) { |     public void testVariableTimeSeries1(Nd4jBackend backend) { | ||||||
|         WorkspaceConfiguration configuration = WorkspaceConfiguration |         WorkspaceConfiguration configuration = WorkspaceConfiguration | ||||||
|                 .builder() |                 .builder() | ||||||
| @ -80,28 +81,28 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends { | |||||||
|                 .build(); |                 .build(); | ||||||
| 
 | 
 | ||||||
|         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { |         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { | ||||||
|             Nd4j.create(500); |             Nd4j.create(DataType.DOUBLE,500); | ||||||
|             Nd4j.create(500); |             Nd4j.create(DataType.DOUBLE,500); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1"); |         Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1"); | ||||||
| 
 | 
 | ||||||
|         assertEquals(0, workspace.getStepNumber()); |         assertEquals(0, workspace.getStepNumber()); | ||||||
| 
 | 
 | ||||||
|         long requiredMemory = 1000 * Nd4j.sizeOfDataType(); |         long requiredMemory = 1000 * DataType.DOUBLE.width(); | ||||||
|         long shiftedSize = ((long) (requiredMemory * 1.3)) + (8 - (((long) (requiredMemory * 1.3)) % 8)); |         long shiftedSize = ((long) (requiredMemory * 1.3)) + (8 - (((long) (requiredMemory * 1.3)) % 8)); | ||||||
|         assertEquals(requiredMemory, workspace.getSpilledSize()); |         assertEquals(requiredMemory, workspace.getSpilledSize()); | ||||||
|         assertEquals(shiftedSize, workspace.getInitialBlockSize()); |         assertEquals(shiftedSize, workspace.getInitialBlockSize()); | ||||||
|         assertEquals(workspace.getInitialBlockSize() * 4, workspace.getCurrentSize()); |         assertEquals(workspace.getInitialBlockSize() * 4, workspace.getCurrentSize()); | ||||||
| 
 | 
 | ||||||
|         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS1")) { |         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS1")) { | ||||||
|             Nd4j.create(2000); |             Nd4j.create(DataType.DOUBLE,2000); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         assertEquals(0, workspace.getStepNumber()); |         assertEquals(0, workspace.getStepNumber()); | ||||||
| 
 | 
 | ||||||
|         assertEquals(1000 * Nd4j.sizeOfDataType(), workspace.getSpilledSize()); |         assertEquals(1000 * DataType.DOUBLE.width(), workspace.getSpilledSize()); | ||||||
|         assertEquals(2000 * Nd4j.sizeOfDataType(), workspace.getPinnedSize()); |         assertEquals(2000 * DataType.DOUBLE.width(), workspace.getPinnedSize()); | ||||||
| 
 | 
 | ||||||
|         assertEquals(0, workspace.getDeviceOffset()); |         assertEquals(0, workspace.getDeviceOffset()); | ||||||
| 
 | 
 | ||||||
| @ -116,8 +117,8 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends { | |||||||
|         for (int e = 0; e < 4; e++) { |         for (int e = 0; e < 4; e++) { | ||||||
|             for (int i = 0; i < 4; i++) { |             for (int i = 0; i < 4; i++) { | ||||||
|                 try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { |                 try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { | ||||||
|                     Nd4j.create(500); |                     Nd4j.create(DataType.DOUBLE,500); | ||||||
|                     Nd4j.create(500); |                     Nd4j.create(DataType.DOUBLE,500); | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
|                 assertEquals((i + 1) * workspace.getInitialBlockSize(), |                 assertEquals((i + 1) * workspace.getInitialBlockSize(), | ||||||
| @ -144,9 +145,9 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends { | |||||||
|         // we just do huge loop now, with pinned stuff in it |         // we just do huge loop now, with pinned stuff in it | ||||||
|         for (int i = 0; i < 100; i++) { |         for (int i = 0; i < 100; i++) { | ||||||
|             try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { |             try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { | ||||||
|                 Nd4j.create(500); |                 Nd4j.create(DataType.DOUBLE,500); | ||||||
|                 Nd4j.create(500); |                 Nd4j.create(DataType.DOUBLE,500); | ||||||
|                 Nd4j.create(500); |                 Nd4j.create(DataType.DOUBLE,500); | ||||||
| 
 | 
 | ||||||
|                 assertEquals(1500 * Nd4j.sizeOfDataType(), workspace.getThisCycleAllocations()); |                 assertEquals(1500 * Nd4j.sizeOfDataType(), workspace.getThisCycleAllocations()); | ||||||
|             } |             } | ||||||
| @ -160,8 +161,8 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends { | |||||||
|         // and we do another clean loo, without pinned stuff in it, to ensure all pinned allocates are gone |         // and we do another clean loo, without pinned stuff in it, to ensure all pinned allocates are gone | ||||||
|         for (int i = 0; i < 100; i++) { |         for (int i = 0; i < 100; i++) { | ||||||
|             try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { |             try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { | ||||||
|                 Nd4j.create(500); |                 Nd4j.create(DataType.DOUBLE,500); | ||||||
|                 Nd4j.create(500); |                 Nd4j.create(DataType.DOUBLE,500); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
| @ -186,13 +187,12 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends { | |||||||
| //        workspace.enableDebug(true); | //        workspace.enableDebug(true); | ||||||
| 
 | 
 | ||||||
|         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { |         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { | ||||||
|             Nd4j.create(500); |             Nd4j.create(DataType.DOUBLE,500); | ||||||
|             Nd4j.create(500); |             Nd4j.create(DataType.DOUBLE,500); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         assertEquals(0, workspace.getStepNumber()); |         assertEquals(0, workspace.getStepNumber()); | ||||||
| 
 |         long requiredMemory = 1000 * DataType.DOUBLE.width(); | ||||||
|         long requiredMemory = 1000 * Nd4j.sizeOfDataType(); |  | ||||||
|         long shiftedSize = ((long) (requiredMemory * 1.3)) + (8 - (((long) (requiredMemory * 1.3)) % 8)); |         long shiftedSize = ((long) (requiredMemory * 1.3)) + (8 - (((long) (requiredMemory * 1.3)) % 8)); | ||||||
|         assertEquals(requiredMemory, workspace.getSpilledSize()); |         assertEquals(requiredMemory, workspace.getSpilledSize()); | ||||||
|         assertEquals(shiftedSize, workspace.getInitialBlockSize()); |         assertEquals(shiftedSize, workspace.getInitialBlockSize()); | ||||||
| @ -200,9 +200,9 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|         for (int i = 0; i < 100; i++) { |         for (int i = 0; i < 100; i++) { | ||||||
|             try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { |             try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) { | ||||||
|                 Nd4j.create(500); |                 Nd4j.create(DataType.DOUBLE,500); | ||||||
|                 Nd4j.create(500); |                 Nd4j.create(DataType.DOUBLE,500); | ||||||
|                 Nd4j.create(500); |                 Nd4j.create(DataType.DOUBLE,500); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
| @ -226,11 +226,11 @@ public class SpecialWorkspaceTests extends BaseNd4jTestWithBackends { | |||||||
|         Nd4jWorkspace workspace = |         Nd4jWorkspace workspace = | ||||||
|                 (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "WS109"); |                 (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "WS109"); | ||||||
| 
 | 
 | ||||||
|         INDArray row = Nd4j.linspace(1, 10, 10); |         INDArray row = Nd4j.linspace(1, 10, 10).castTo(DataType.DOUBLE); | ||||||
|         INDArray exp = Nd4j.create(10).assign(2.0); |         INDArray exp = Nd4j.create(DataType.DOUBLE,10).assign(2.0); | ||||||
|         INDArray result = null; |         INDArray result = null; | ||||||
|         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS109")) { |         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS109")) { | ||||||
|             INDArray matrix = Nd4j.create(10, 10); |             INDArray matrix = Nd4j.create(DataType.DOUBLE,10, 10); | ||||||
|             for (int e = 0; e < matrix.rows(); e++) |             for (int e = 0; e < matrix.rows(); e++) | ||||||
|                 matrix.getRow(e).assign(row); |                 matrix.getRow(e).assign(row); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -26,6 +26,8 @@ import org.junit.jupiter.api.AfterEach; | |||||||
| import org.junit.jupiter.api.Disabled; | import org.junit.jupiter.api.Disabled; | ||||||
| import org.junit.jupiter.api.Tag; | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
|  | import org.junit.jupiter.api.parallel.Execution; | ||||||
|  | import org.junit.jupiter.api.parallel.ExecutionMode; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| 
 | 
 | ||||||
| @ -57,6 +59,7 @@ import static org.junit.jupiter.api.Assertions.*; | |||||||
| @Slf4j | @Slf4j | ||||||
| @Tag(TagNames.WORKSPACES) | @Tag(TagNames.WORKSPACES) | ||||||
| @NativeTag | @NativeTag | ||||||
|  | @Execution(ExecutionMode.SAME_THREAD) | ||||||
| public class WorkspaceProviderTests extends BaseNd4jTestWithBackends { | public class WorkspaceProviderTests extends BaseNd4jTestWithBackends { | ||||||
|     private static final WorkspaceConfiguration basicConfiguration = WorkspaceConfiguration.builder().initialSize(81920) |     private static final WorkspaceConfiguration basicConfiguration = WorkspaceConfiguration.builder().initialSize(81920) | ||||||
|             .overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL).policyLearning(LearningPolicy.NONE) |             .overallocationLimit(0.1).policySpill(SpillPolicy.EXTERNAL).policyLearning(LearningPolicy.NONE) | ||||||
| @ -119,7 +122,6 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends { | |||||||
|     public void shutUp() { |     public void shutUp() { | ||||||
|         Nd4j.getMemoryManager().setCurrentWorkspace(null); |         Nd4j.getMemoryManager().setCurrentWorkspace(null); | ||||||
|         Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); |         Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); | ||||||
|         Nd4j.setDataType(this.initialType); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
| @ -144,7 +146,7 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends { | |||||||
|         for (int x = 0; x < 100; x++) { |         for (int x = 0; x < 100; x++) { | ||||||
|             try (Nd4jWorkspace wsI = (Nd4jWorkspace) Nd4j.getWorkspaceManager() |             try (Nd4jWorkspace wsI = (Nd4jWorkspace) Nd4j.getWorkspaceManager() | ||||||
|                     .getWorkspaceForCurrentThread(configuration, "ITER").notifyScopeEntered()) { |                     .getWorkspaceForCurrentThread(configuration, "ITER").notifyScopeEntered()) { | ||||||
|                 INDArray array = Nd4j.create(100); |                 INDArray array = Nd4j.create(DataType.DOUBLE,100); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             // only checking after workspace is initialized |             // only checking after workspace is initialized | ||||||
| @ -174,7 +176,7 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends { | |||||||
|             try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager() |             try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager() | ||||||
|                     .getWorkspaceForCurrentThread(configuration, "ITER").notifyScopeEntered()) { |                     .getWorkspaceForCurrentThread(configuration, "ITER").notifyScopeEntered()) { | ||||||
| 
 | 
 | ||||||
|                 INDArray array = Nd4j.create(100); |                 INDArray array = Nd4j.create(DataType.DOUBLE,100); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, |             Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, | ||||||
| @ -200,7 +202,7 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testMultithreading1() throws Exception { |     public void testMultithreading1(Nd4jBackend backend) throws Exception { | ||||||
|         final List<MemoryWorkspace> workspaces = new CopyOnWriteArrayList<>(); |         final List<MemoryWorkspace> workspaces = new CopyOnWriteArrayList<>(); | ||||||
|         Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration); |         Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration); | ||||||
| 
 | 
 | ||||||
| @ -283,21 +285,23 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|  |     @Disabled | ||||||
|  |     @Tag(TagNames.NEEDS_VERIFY) | ||||||
|     public void testNestedWorkspacesOverlap1(Nd4jBackend backend) { |     public void testNestedWorkspacesOverlap1(Nd4jBackend backend) { | ||||||
|         Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); |  | ||||||
|         Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration); |         Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration); | ||||||
|         try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1").notifyScopeEntered()) { |         try (Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1").notifyScopeEntered()) { | ||||||
|             INDArray array = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); |             INDArray array = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); | ||||||
| 
 | 
 | ||||||
|             long reqMem = 5 * Nd4j.sizeOfDataType(); |             long reqMem = 5 * array.dataType().width(); | ||||||
|             assertEquals(reqMem + (Nd4jWorkspace.alignmentBase - reqMem % Nd4jWorkspace.alignmentBase), ws1.getPrimaryOffset()); |             long add = ((Nd4jWorkspace.alignmentBase / 2) - reqMem % (Nd4jWorkspace.alignmentBase / 2)); | ||||||
|  |             assertEquals(reqMem + add, ws1.getPrimaryOffset()); | ||||||
|             try (Nd4jWorkspace ws2 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS2").notifyScopeEntered()) { |             try (Nd4jWorkspace ws2 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS2").notifyScopeEntered()) { | ||||||
| 
 | 
 | ||||||
|                 INDArray array2 = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); |                 INDArray array2 = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); | ||||||
| 
 | 
 | ||||||
|                 reqMem = 5 * Nd4j.sizeOfDataType(); |                 reqMem = 5 * array2.dataType().width(); | ||||||
|                 assertEquals(reqMem + (Nd4jWorkspace.alignmentBase - reqMem % Nd4jWorkspace.alignmentBase), ws1.getPrimaryOffset()); |                 assertEquals(reqMem + ((Nd4jWorkspace.alignmentBase / 2) - reqMem % (Nd4jWorkspace.alignmentBase / 2)), ws1.getPrimaryOffset()); | ||||||
|                 assertEquals(reqMem + (Nd4jWorkspace.alignmentBase - reqMem % Nd4jWorkspace.alignmentBase), ws2.getPrimaryOffset()); |                 assertEquals(reqMem + ((Nd4jWorkspace.alignmentBase / 2) - reqMem % (Nd4jWorkspace.alignmentBase / 2)), ws2.getPrimaryOffset()); | ||||||
| 
 | 
 | ||||||
|                 try (Nd4jWorkspace ws3 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1") |                 try (Nd4jWorkspace ws3 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1") | ||||||
|                         .notifyScopeBorrowed()) { |                         .notifyScopeBorrowed()) { | ||||||
| @ -305,8 +309,8 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|                     INDArray array3 = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); |                     INDArray array3 = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); | ||||||
| 
 | 
 | ||||||
|                     assertEquals(reqMem + (Nd4jWorkspace.alignmentBase - reqMem % Nd4jWorkspace.alignmentBase), ws2.getPrimaryOffset()); |                     assertEquals(reqMem + ((Nd4jWorkspace.alignmentBase / 2) - reqMem % (Nd4jWorkspace.alignmentBase / 2)), ws2.getPrimaryOffset()); | ||||||
|                     assertEquals((reqMem + (Nd4jWorkspace.alignmentBase - reqMem % Nd4jWorkspace.alignmentBase)) * 2, ws1.getPrimaryOffset()); |                     assertEquals((reqMem + ((Nd4jWorkspace.alignmentBase / 2) - reqMem % (Nd4jWorkspace.alignmentBase / 2))) * 2, ws1.getPrimaryOffset()); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @ -317,7 +321,7 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends { | |||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testWorkspacesSerde3() throws Exception { |     public void testWorkspacesSerde3() throws Exception { | ||||||
|         INDArray array = Nd4j.create(10).assign(1.0); |         INDArray array = Nd4j.create(DataType.DOUBLE,10).assign(1.0); | ||||||
|         INDArray restored = null; |         INDArray restored = null; | ||||||
| 
 | 
 | ||||||
|         ByteArrayOutputStream bos = new ByteArrayOutputStream(); |         ByteArrayOutputStream bos = new ByteArrayOutputStream(); | ||||||
| @ -600,7 +604,7 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends { | |||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testReallocate1(Nd4jBackend backend) { |     public void testReallocate1(Nd4jBackend backend) { | ||||||
|         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateConfiguration, "WS_1")) { |         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateConfiguration, "WS_1")) { | ||||||
|             INDArray array = Nd4j.create(100); |             INDArray array = Nd4j.create(DataType.DOUBLE,100); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -612,7 +616,7 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends { | |||||||
|         assertEquals(100 * Nd4j.sizeOfDataType(), workspace.getCurrentSize()); |         assertEquals(100 * Nd4j.sizeOfDataType(), workspace.getCurrentSize()); | ||||||
| 
 | 
 | ||||||
|         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateConfiguration, "WS_1")) { |         try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateConfiguration, "WS_1")) { | ||||||
|             INDArray array = Nd4j.create(1000); |             INDArray array = Nd4j.create(DataType.DOUBLE,1000); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         assertEquals(1000 * Nd4j.sizeOfDataType(), workspace.getMaxCycleAllocations()); |         assertEquals(1000 * Nd4j.sizeOfDataType(), workspace.getMaxCycleAllocations()); | ||||||
| @ -634,14 +638,14 @@ public class WorkspaceProviderTests extends BaseNd4jTestWithBackends { | |||||||
|     public void testNestedWorkspaces11(Nd4jBackend backend) { |     public void testNestedWorkspaces11(Nd4jBackend backend) { | ||||||
|         for (int x = 1; x < 10; x++) { |         for (int x = 1; x < 10; x++) { | ||||||
|             try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) { |             try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) { | ||||||
|                 INDArray array1 = Nd4j.create(100 * x); |                 INDArray array1 = Nd4j.create(DataType.DOUBLE,100 * x); | ||||||
| 
 | 
 | ||||||
|                 for (int i = 1; i < 10; i++) { |                 for (int i = 1; i < 10; i++) { | ||||||
|                     try (MemoryWorkspace ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) { |                     try (MemoryWorkspace ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) { | ||||||
|                         INDArray array2 = Nd4j.create(100 * x); |                         INDArray array2 = Nd4j.create(DataType.DOUBLE,100 * x); | ||||||
|                         for (int e = 1; e < 10; e++) { |                         for (int e = 1; e < 10; e++) { | ||||||
|                             try (MemoryWorkspace ws3 = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(basicConfiguration, "WS_1").notifyScopeBorrowed()) { |                             try (MemoryWorkspace ws3 = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(basicConfiguration, "WS_1").notifyScopeBorrowed()) { | ||||||
|                                 INDArray array3 = Nd4j.create(100 * x); |                                 INDArray array3 = Nd4j.create(DataType.DOUBLE,100 * x); | ||||||
|                             } |                             } | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|  | |||||||
| @ -55,7 +55,7 @@ public abstract class BaseND4JTest { | |||||||
|      * Override this method to set the default timeout for methods in the test class |      * Override this method to set the default timeout for methods in the test class | ||||||
|      */ |      */ | ||||||
|     public long getTimeoutMilliseconds(){ |     public long getTimeoutMilliseconds(){ | ||||||
|         return 90_000; |         return 180_000; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
| @ -95,7 +95,7 @@ public abstract class BaseND4JTest { | |||||||
|     /** |     /** | ||||||
|      * @return True if integration tests maven profile is enabled, false otherwise. |      * @return True if integration tests maven profile is enabled, false otherwise. | ||||||
|      */ |      */ | ||||||
|     public boolean isIntegrationTests(){ |     public boolean isIntegrationTests() { | ||||||
|         if(integrationTest == null){ |         if(integrationTest == null){ | ||||||
|             String prop = System.getenv("DL4J_INTEGRATION_TESTS"); |             String prop = System.getenv("DL4J_INTEGRATION_TESTS"); | ||||||
|             integrationTest = Boolean.parseBoolean(prop); |             integrationTest = Boolean.parseBoolean(prop); | ||||||
|  | |||||||
| @ -0,0 +1,38 @@ | |||||||
|  | /* | ||||||
|  |  * | ||||||
|  |  *  *  ****************************************************************************** | ||||||
|  |  *  *  * | ||||||
|  |  *  *  * | ||||||
|  |  *  *  * This program and the accompanying materials are made available under the | ||||||
|  |  *  *  * terms of the Apache License, Version 2.0 which is available at | ||||||
|  |  *  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||||
|  |  *  *  * | ||||||
|  |  *  *  *  See the NOTICE file distributed with this work for additional | ||||||
|  |  *  *  *  information regarding copyright ownership. | ||||||
|  |  *  *  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  *  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||||
|  |  *  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||||
|  |  *  *  * License for the specific language governing permissions and limitations | ||||||
|  |  *  *  * under the License. | ||||||
|  |  *  *  * | ||||||
|  |  *  *  * SPDX-License-Identifier: Apache-2.0 | ||||||
|  |  *  *  ***************************************************************************** | ||||||
|  |  * | ||||||
|  |  * | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package org.nd4j.common.tests.tags; | ||||||
|  | 
 | ||||||
|  | import org.junit.jupiter.api.Tag; | ||||||
|  | 
 | ||||||
|  | import java.lang.annotation.ElementType; | ||||||
|  | import java.lang.annotation.Retention; | ||||||
|  | import java.lang.annotation.RetentionPolicy; | ||||||
|  | import java.lang.annotation.Target; | ||||||
|  | 
 | ||||||
|  | @Target({ElementType.TYPE, ElementType.METHOD}) | ||||||
|  | @Retention(RetentionPolicy.RUNTIME) | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
|  | @Tag(TagNames.LONG_TEST) | ||||||
|  | public @interface ExpensiveTest { | ||||||
|  | } | ||||||
| @ -50,4 +50,5 @@ public class TagNames { | |||||||
|     public final static String PYTHON = "python"; |     public final static String PYTHON = "python"; | ||||||
|     public final static String LONG_TEST = "long-running-test"; |     public final static String LONG_TEST = "long-running-test"; | ||||||
|     public final static String NEEDS_VERIFY = "needs-verify"; //tests that need verification of issue |     public final static String NEEDS_VERIFY = "needs-verify"; //tests that need verification of issue | ||||||
|  |     public final static String LARGE_RESOURCES = "large-resources"; | ||||||
| } | } | ||||||
|  | |||||||
| @ -106,6 +106,7 @@ | |||||||
|                         <groupId>org.apache.maven.plugins</groupId> |                         <groupId>org.apache.maven.plugins</groupId> | ||||||
|                         <artifactId>maven-surefire-plugin</artifactId> |                         <artifactId>maven-surefire-plugin</artifactId> | ||||||
|                         <configuration> |                         <configuration> | ||||||
|  |                             <forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/> | ||||||
|                             <forkCount>${cpu.core.count}</forkCount> |                             <forkCount>${cpu.core.count}</forkCount> | ||||||
|                             <reuseForks>false</reuseForks> |                             <reuseForks>false</reuseForks> | ||||||
|                             <environmentVariables> |                             <environmentVariables> | ||||||
| @ -116,7 +117,8 @@ | |||||||
|                                 <include>*.java</include> |                                 <include>*.java</include> | ||||||
|                                 <include>**/*.java</include> |                                 <include>**/*.java</include> | ||||||
|                             </includes> |                             </includes> | ||||||
|                             <argLine> -Xmx8g </argLine> |                             <argLine>-Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine> | ||||||
|  | 
 | ||||||
|                         </configuration> |                         </configuration> | ||||||
|                     </plugin> |                     </plugin> | ||||||
|                 </plugins> |                 </plugins> | ||||||
| @ -140,9 +142,11 @@ | |||||||
|                         <groupId>org.apache.maven.plugins</groupId> |                         <groupId>org.apache.maven.plugins</groupId> | ||||||
|                         <artifactId>maven-surefire-plugin</artifactId> |                         <artifactId>maven-surefire-plugin</artifactId> | ||||||
|                         <configuration> |                         <configuration> | ||||||
|  |                             <forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/> | ||||||
|                             <forkCount>${cpu.core.count}</forkCount> |                             <forkCount>${cpu.core.count}</forkCount> | ||||||
|                             <reuseForks>false</reuseForks> |                             <reuseForks>false</reuseForks> | ||||||
|                             <argLine>-Xmx8g</argLine> |                             <argLine>-Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine> | ||||||
|  | 
 | ||||||
|                         </configuration> |                         </configuration> | ||||||
|                     </plugin> |                     </plugin> | ||||||
|                 </plugins> |                 </plugins> | ||||||
|  | |||||||
| @ -98,6 +98,7 @@ | |||||||
|                         <groupId>org.apache.maven.plugins</groupId> |                         <groupId>org.apache.maven.plugins</groupId> | ||||||
|                         <artifactId>maven-surefire-plugin</artifactId> |                         <artifactId>maven-surefire-plugin</artifactId> | ||||||
|                         <configuration> |                         <configuration> | ||||||
|  |                             <forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/> | ||||||
|                             <forkCount>${cpu.core.count}</forkCount> |                             <forkCount>${cpu.core.count}</forkCount> | ||||||
|                             <reuseForks>false</reuseForks> |                             <reuseForks>false</reuseForks> | ||||||
|                             <testSourceDirectory>src/test/java</testSourceDirectory> |                             <testSourceDirectory>src/test/java</testSourceDirectory> | ||||||
| @ -105,7 +106,8 @@ | |||||||
|                                 <include>*.java</include> |                                 <include>*.java</include> | ||||||
|                                 <include>**/*.java</include> |                                 <include>**/*.java</include> | ||||||
|                             </includes> |                             </includes> | ||||||
|                             <argLine> </argLine> |                             <argLine>-Xmx${test.heap.size} -Dorg.bytedeco.javacpp.maxphysicalbytes=${test.offheap.size} -Dorg.bytedeco.javacpp.maxbytes=${test.offheap.size}</argLine> | ||||||
|  | 
 | ||||||
|                         </configuration> |                         </configuration> | ||||||
|                     </plugin> |                     </plugin> | ||||||
|                 </plugins> |                 </plugins> | ||||||
|  | |||||||
| @ -20,35 +20,61 @@ | |||||||
| 
 | 
 | ||||||
| package org.nd4j; | package org.nd4j; | ||||||
| 
 | 
 | ||||||
|  | import com.sun.jna.Platform; | ||||||
| import lombok.AllArgsConstructor; | import lombok.AllArgsConstructor; | ||||||
|  | import lombok.SneakyThrows; | ||||||
|  | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.apache.spark.SparkConf; | import org.apache.spark.SparkConf; | ||||||
| import org.apache.spark.api.java.JavaRDD; | import org.apache.spark.api.java.JavaRDD; | ||||||
| import org.apache.spark.api.java.JavaSparkContext; | import org.apache.spark.api.java.JavaSparkContext; | ||||||
| import org.apache.spark.api.java.function.VoidFunction; | import org.apache.spark.api.java.function.VoidFunction; | ||||||
| import org.apache.spark.broadcast.Broadcast; | import org.apache.spark.broadcast.Broadcast; | ||||||
| import org.apache.spark.serializer.SerializerInstance; | import org.apache.spark.serializer.SerializerInstance; | ||||||
| import org.junit.jupiter.api.AfterEach; | import org.junit.jupiter.api.*; | ||||||
| import org.junit.jupiter.api.BeforeEach; |  | ||||||
| import org.junit.jupiter.api.Disabled; |  | ||||||
| import org.junit.jupiter.api.Test; |  | ||||||
| import org.nd4j.common.primitives.*; | import org.nd4j.common.primitives.*; | ||||||
|  | import org.nd4j.common.resources.Downloader; | ||||||
| import org.nd4j.common.tests.BaseND4JTest; | import org.nd4j.common.tests.BaseND4JTest; | ||||||
|  | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.linalg.api.buffer.DataType; | import org.nd4j.linalg.api.buffer.DataType; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| import scala.Tuple2; | import scala.Tuple2; | ||||||
| 
 | 
 | ||||||
|  | import java.io.File; | ||||||
|  | import java.net.URI; | ||||||
| import java.nio.ByteBuffer; | import java.nio.ByteBuffer; | ||||||
| import java.util.ArrayList; | import java.util.ArrayList; | ||||||
| import java.util.List; | import java.util.List; | ||||||
| 
 | 
 | ||||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | import static org.junit.jupiter.api.Assertions.assertEquals; | ||||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | import static org.junit.jupiter.api.Assertions.assertTrue; | ||||||
| @Disabled("Ignoring due to flaky nature of tests") | @Slf4j | ||||||
|  | @Tag(TagNames.SPARK) | ||||||
|  | @Tag(TagNames.DIST_SYSTEMS) | ||||||
| public class TestNd4jKryoSerialization extends BaseND4JTest { | public class TestNd4jKryoSerialization extends BaseND4JTest { | ||||||
| 
 | 
 | ||||||
|     private JavaSparkContext sc; |     private JavaSparkContext sc; | ||||||
| 
 | 
 | ||||||
|  |     @BeforeAll | ||||||
|  |     @SneakyThrows | ||||||
|  |     public static void beforeAll() { | ||||||
|  |         if(Platform.isWindows()) { | ||||||
|  |             File hadoopHome = new File(System.getProperty("java.io.tmpdir"),"hadoop-tmp"); | ||||||
|  |             File binDir = new File(hadoopHome,"bin"); | ||||||
|  |             if(!binDir.exists()) | ||||||
|  |                 binDir.mkdirs(); | ||||||
|  |             File outputFile = new File(binDir,"winutils.exe"); | ||||||
|  |             if(!outputFile.exists()) { | ||||||
|  |                 log.info("Fixing spark for windows"); | ||||||
|  |                 Downloader.download("winutils.exe", | ||||||
|  |                         URI.create("https://github.com/cdarlint/winutils/blob/master/hadoop-2.6.5/bin/winutils.exe?raw=true").toURL(), | ||||||
|  |                         outputFile,"db24b404d2331a1bec7443336a5171f1",3); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             System.setProperty("hadoop.home.dir", hadoopHome.getAbsolutePath()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @BeforeEach |     @BeforeEach | ||||||
|     public void before() { |     public void before() { | ||||||
|         SparkConf sparkConf = new SparkConf(); |         SparkConf sparkConf = new SparkConf(); | ||||||
|  | |||||||
| @ -49,6 +49,12 @@ | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|   <dependencies> |   <dependencies> | ||||||
|  |     <dependency> | ||||||
|  |       <groupId>org.nd4j</groupId> | ||||||
|  |       <artifactId>nd4j-common-tests</artifactId> | ||||||
|  |       <version>${project.version}</version> | ||||||
|  |       <scope>test</scope> | ||||||
|  |     </dependency> | ||||||
|     <dependency> |     <dependency> | ||||||
|       <groupId>org.nd4j</groupId> |       <groupId>org.nd4j</groupId> | ||||||
|       <artifactId>samediff-import-api</artifactId> |       <artifactId>samediff-import-api</artifactId> | ||||||
|  | |||||||
| @ -40,6 +40,7 @@ import org.apache.commons.io.FileUtils | |||||||
| import org.junit.jupiter.api.Disabled | import org.junit.jupiter.api.Disabled | ||||||
| import org.junit.jupiter.api.Test | import org.junit.jupiter.api.Test | ||||||
| import org.nd4j.common.resources.Downloader | import org.nd4j.common.resources.Downloader | ||||||
|  | import org.nd4j.common.tests.tags.ExpensiveTest | ||||||
| import org.nd4j.common.util.ArchiveUtils | import org.nd4j.common.util.ArchiveUtils | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray | import org.nd4j.linalg.api.ndarray.INDArray | ||||||
| import org.nd4j.samediff.frameworkimport.onnx.importer.OnnxFrameworkImporter | import org.nd4j.samediff.frameworkimport.onnx.importer.OnnxFrameworkImporter | ||||||
| @ -50,7 +51,7 @@ import java.io.File | |||||||
| import java.net.URI | import java.net.URI | ||||||
| 
 | 
 | ||||||
| data class InputDataset(val dataSetIndex: Int,val inputPaths: List<String>,val outputPaths: List<String>) | data class InputDataset(val dataSetIndex: Int,val inputPaths: List<String>,val outputPaths: List<String>) | ||||||
| @Disabled | @ExpensiveTest | ||||||
| class TestPretrainedModels { | class TestPretrainedModels { | ||||||
| 
 | 
 | ||||||
|     val modelBaseUrl = "https://media.githubusercontent.com/media/onnx/models/master" |     val modelBaseUrl = "https://media.githubusercontent.com/media/onnx/models/master" | ||||||
|  | |||||||
| @ -22,6 +22,7 @@ package org.nd4j.imports.tfgraphs; | |||||||
| 
 | 
 | ||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.junit.jupiter.api.Disabled; | import org.junit.jupiter.api.Disabled; | ||||||
|  | import org.junit.jupiter.api.Tag; | ||||||
| import org.junit.jupiter.api.Test; | import org.junit.jupiter.api.Test; | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
| @ -33,6 +34,7 @@ import org.nd4j.autodiff.samediff.transform.OpPredicate; | |||||||
| import org.nd4j.autodiff.samediff.transform.SubGraph; | import org.nd4j.autodiff.samediff.transform.SubGraph; | ||||||
| import org.nd4j.autodiff.samediff.transform.SubGraphPredicate; | import org.nd4j.autodiff.samediff.transform.SubGraphPredicate; | ||||||
| import org.nd4j.autodiff.samediff.transform.SubGraphProcessor; | import org.nd4j.autodiff.samediff.transform.SubGraphProcessor; | ||||||
|  | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.graph.ui.LogFileWriter; | import org.nd4j.graph.ui.LogFileWriter; | ||||||
| import org.nd4j.imports.graphmapper.tf.TFGraphMapper; | import org.nd4j.imports.graphmapper.tf.TFGraphMapper; | ||||||
| import org.nd4j.imports.tensorflow.TFImportOverride; | import org.nd4j.imports.tensorflow.TFImportOverride; | ||||||
| @ -55,7 +57,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; | |||||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | import static org.junit.jupiter.api.Assertions.assertTrue; | ||||||
| 
 | 
 | ||||||
| @Slf4j | @Slf4j | ||||||
| @Disabled("AB 2019/05/21 - JVM Crash on linux-x86_64-cuda-9.2, linux-ppc64le-cpu - Issue #7657") | @Tag(TagNames.LONG_TEST) | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
| public class BERTGraphTest extends BaseNd4jTestWithBackends { | public class BERTGraphTest extends BaseNd4jTestWithBackends { | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -44,7 +44,7 @@ public class CustomOpTests extends BaseNd4jTestWithBackends { | |||||||
| 
 | 
 | ||||||
|     @ParameterizedTest |     @ParameterizedTest | ||||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") |     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||||
|     public void testPad(Nd4jBackend backend){ |     public void testPad(Nd4jBackend backend) { | ||||||
| 
 | 
 | ||||||
|         INDArray in = Nd4j.create(DataType.FLOAT, 1, 28, 28, 264); |         INDArray in = Nd4j.create(DataType.FLOAT, 1, 28, 28, 264); | ||||||
|         INDArray pad = Nd4j.createFromArray(new int[][]{{0,0},{0,1},{0,1},{0,0}}); |         INDArray pad = Nd4j.createFromArray(new int[][]{{0,0},{0,1},{0,1},{0,0}}); | ||||||
|  | |||||||
| @ -27,6 +27,7 @@ import org.junit.jupiter.api.BeforeEach; | |||||||
| import org.junit.jupiter.api.Disabled; | import org.junit.jupiter.api.Disabled; | ||||||
| import org.junit.jupiter.params.provider.Arguments; | import org.junit.jupiter.params.provider.Arguments; | ||||||
| 
 | 
 | ||||||
|  | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.linalg.api.buffer.DataType; | import org.nd4j.linalg.api.buffer.DataType; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.executioner.OpExecutioner; | import org.nd4j.linalg.api.ops.executioner.OpExecutioner; | ||||||
| @ -41,7 +42,8 @@ import java.util.stream.Stream; | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @Slf4j | @Slf4j | ||||||
| @Disabled("AB 2019/05/21 - JVM Crashes - Issue #7657") | @Tag(TagNames.LONG_TEST) | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
| public class TFGraphTestAllLibnd4j {   //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests | public class TFGraphTestAllLibnd4j {   //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests | ||||||
| 
 | 
 | ||||||
|     private Map<String, INDArray> inputs; |     private Map<String, INDArray> inputs; | ||||||
|  | |||||||
| @ -26,6 +26,7 @@ import org.junit.jupiter.api.*; | |||||||
| 
 | 
 | ||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.Arguments; | import org.junit.jupiter.params.provider.Arguments; | ||||||
|  | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.linalg.api.buffer.DataType; | import org.nd4j.linalg.api.buffer.DataType; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.executioner.OpExecutioner; | import org.nd4j.linalg.api.ops.executioner.OpExecutioner; | ||||||
| @ -38,6 +39,8 @@ import java.util.*; | |||||||
| import java.util.stream.Stream; | import java.util.stream.Stream; | ||||||
| 
 | 
 | ||||||
| @Slf4j | @Slf4j | ||||||
|  | @Tag(TagNames.LONG_TEST) | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
| public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests | public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -29,6 +29,7 @@ import org.junit.jupiter.api.io.TempDir; | |||||||
| import org.junit.jupiter.params.ParameterizedTest; | import org.junit.jupiter.params.ParameterizedTest; | ||||||
| import org.junit.jupiter.params.provider.Arguments; | import org.junit.jupiter.params.provider.Arguments; | ||||||
| import org.junit.jupiter.params.provider.MethodSource; | import org.junit.jupiter.params.provider.MethodSource; | ||||||
|  | import org.nd4j.common.tests.tags.TagNames; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.executioner.OpExecutioner; | import org.nd4j.linalg.api.ops.executioner.OpExecutioner; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| @ -44,7 +45,8 @@ import java.util.Map; | |||||||
| import java.util.stream.Stream; | import java.util.stream.Stream; | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @Disabled | @Tag(TagNames.LONG_TEST) | ||||||
|  | @Tag(TagNames.LARGE_RESOURCES) | ||||||
| public class TFGraphTestList { | public class TFGraphTestList { | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user