More test fixes

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2022-10-07 14:59:54 +02:00
parent a7f75fe6db
commit 9660ab026d
39 changed files with 8 additions and 183 deletions

View File

@ -16,4 +16,6 @@ dependencies {
implementation 'org.apache.commons:commons-math3' implementation 'org.apache.commons:commons-math3'
implementation 'org.apache.commons:commons-lang3' implementation 'org.apache.commons:commons-lang3'
implementation 'org.apache.commons:commons-compress' implementation 'org.apache.commons:commons-compress'
testRuntimeOnly 'net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT'
} }

View File

@ -27,11 +27,6 @@ import org.junit.jupiter.api.Test;
public class TestDataSets extends BaseDL4JTest { public class TestDataSets extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L;
}
@Test @Test
public void testTinyImageNetExists() throws Exception { public void testTinyImageNetExists() throws Exception {
//Simple sanity check on extracting //Simple sanity check on extracting

View File

@ -35,11 +35,6 @@ import static org.junit.jupiter.api.Assumptions.assumeTrue;
*/ */
public class SvhnDataFetcherTest extends BaseDL4JTest { public class SvhnDataFetcherTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 480_000_000L; //Shouldn't take this long but slow download or drive access on CI machines may need extra time.
}
@Test @Test
public void testSvhnDataFetcher() throws Exception { public void testSvhnDataFetcher() throws Exception {
assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access

View File

@ -59,11 +59,6 @@ import static org.junit.jupiter.api.Assertions.*;
public class DataSetIteratorTest extends BaseDL4JTest { public class DataSetIteratorTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 360000; //Should run quickly; increased to large timeout due to occasonal slow CI downloads
}
@Test @Test
public void testBatchSizeOfOneIris() throws Exception { public void testBatchSizeOfOneIris() throws Exception {
//Test for (a) iterators returning correct number of examples, and //Test for (a) iterators returning correct number of examples, and

View File

@ -51,10 +51,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
////@Ignore ////@Ignore
public class AttentionLayerTest extends BaseDL4JTest { public class AttentionLayerTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testSelfAttentionLayer() { public void testSelfAttentionLayer() {

View File

@ -62,11 +62,6 @@ public class BNGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testGradient2dSimple() { public void testGradient2dSimple() {
DataNormalization scaler = new NormalizerMinMaxScaler(); DataNormalization scaler = new NormalizerMinMaxScaler();

View File

@ -62,11 +62,6 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 180000;
}
@Test @Test
public void testCnn1DWithLocallyConnected1D() { public void testCnn1DWithLocallyConnected1D() {
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);

View File

@ -59,11 +59,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testCnn3DPlain() { public void testCnn3DPlain() {
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);

View File

@ -73,11 +73,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
return CNN2DFormat.values(); return CNN2DFormat.values();
} }
@Override
public long getTimeoutMilliseconds() {
return 999990000L;
}
@Test @Test
public void testGradientCNNMLN() { public void testGradientCNNMLN() {
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format... if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...

View File

@ -49,11 +49,6 @@ import java.util.Random;
////@Ignore ////@Ignore
public class CapsnetGradientCheckTest extends BaseDL4JTest { public class CapsnetGradientCheckTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testCapsNet() { public void testCapsNet() {

View File

@ -59,11 +59,6 @@ public class DropoutGradientCheck extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testDropoutGradient() { public void testDropoutGradient() {
int minibatch = 3; int minibatch = 3;

View File

@ -55,11 +55,6 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
private static final double DEFAULT_MAX_REL_ERROR = 1e-3; private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testRNNGlobalPoolingBasicMultiLayer() { public void testRNNGlobalPoolingBasicMultiLayer() {
//Basic test of global pooling w/ LSTM //Basic test of global pooling w/ LSTM

View File

@ -70,11 +70,6 @@ public class GradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testMinibatchApplication() { public void testMinibatchApplication() {
IrisDataSetIterator iter = new IrisDataSetIterator(30, 150); IrisDataSetIterator iter = new IrisDataSetIterator(30, 150);

View File

@ -71,11 +71,6 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 999999999L;
}
@Test @Test
public void testBasicIris() { public void testBasicIris() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -59,11 +59,6 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
private static class GradientCheckSimpleScenario { private static class GradientCheckSimpleScenario {
private final ILossFunction lf; private final ILossFunction lf;
private final Activation act; private final Activation act;

View File

@ -54,12 +54,6 @@ public class LRNGradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testGradientLRNSimple() { public void testGradientLRNSimple() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -55,11 +55,6 @@ public class LSTMGradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testLSTMBasicMultiLayer() { public void testLSTMBasicMultiLayer() {
//Basic test of GravesLSTM layer //Basic test of GravesLSTM layer

View File

@ -73,11 +73,6 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
private static final double DEFAULT_MAX_REL_ERROR = 1e-5; private static final double DEFAULT_MAX_REL_ERROR = 1e-5;
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void lossFunctionGradientCheck() { public void lossFunctionGradientCheck() {
ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), new LossBinaryXENT(), ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), new LossBinaryXENT(),

View File

@ -52,11 +52,6 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testGradientNoBiasDenseOutput() { public void testGradientNoBiasDenseOutput() {

View File

@ -52,11 +52,6 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testRnnLossLayer() { public void testRnnLossLayer() {
Nd4j.getRandom().setSeed(12345L); Nd4j.getRandom().setSeed(12345L);

View File

@ -55,11 +55,6 @@ public class RnnGradientChecks extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
////@Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") ////@Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
public void testBidirectionalWrapper() { public void testBidirectionalWrapper() {

View File

@ -56,11 +56,6 @@ public class UtilLayerGradientChecks extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testMaskLayer() { public void testMaskLayer() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -57,11 +57,6 @@ public class VaeGradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testVaeAsMLP() { public void testVaeAsMLP() {
//Post pre-training: a VAE can be used as a MLP, by taking the mean value from p(z|x) as the output //Post pre-training: a VAE can be used as a MLP, by taking the mean value from p(z|x) as the output

View File

@ -72,11 +72,6 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
@TempDir @TempDir
public File testDir; public File testDir;
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testYoloOutputLayer() { public void testYoloOutputLayer() {
int depthIn = 2; int depthIn = 2;

View File

@ -186,11 +186,6 @@ public class DTypeTests extends BaseDL4JTest {
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
)); ));
@Override
public long getTimeoutMilliseconds() {
return 9999999L;
}
@AfterAll @AfterAll
public static void after() { public static void after() {
ImmutableSet<ClassPath.ClassInfo> info; ImmutableSet<ClassPath.ClassInfo> info;

View File

@ -93,11 +93,6 @@ public class BatchNormalizationTest extends BaseDL4JTest {
public void doBefore() { public void doBefore() {
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testDnnForwardPass() { public void testDnnForwardPass() {
int nOut = 10; int nOut = 10;

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.multilayer; package org.deeplearning4j.nn.multilayer;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
@ -1424,6 +1425,7 @@ public class MultiLayerTest extends BaseDL4JTest {
} }
@Data @Data
@EqualsAndHashCode(callSuper = false)
public static class CheckModelsListener extends BaseTrainingListener { public static class CheckModelsListener extends BaseTrainingListener {
private Set<Class<?>> modelClasses = new HashSet<>(); private Set<Class<?>> modelClasses = new HashSet<>();

View File

@ -39,11 +39,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j @Slf4j
public class EncodedGradientsAccumulatorTest extends BaseDL4JTest { public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 1200000L;
}
/** /**
* This test ensures, that memory amount assigned to buffer is enough for any number of updates * This test ensures, that memory amount assigned to buffer is enough for any number of updates
* @throws Exception * @throws Exception

View File

@ -47,11 +47,6 @@ import static org.junit.jupiter.api.Assertions.*;
public class TestCheckpointListener extends BaseDL4JTest { public class TestCheckpointListener extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@TempDir @TempDir
public File tempDir; public File tempDir;

View File

@ -67,11 +67,6 @@ public class TestListeners extends BaseDL4JTest {
@TempDir @TempDir
public File tempDir; public File tempDir;
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testSettingListenersUnsupervised() { public void testSettingListenersUnsupervised() {
//Pretrain layers should get copies of the listeners, in addition to the //Pretrain layers should get copies of the listeners, in addition to the

View File

@ -60,11 +60,6 @@ public class RegressionTest060 extends BaseDL4JTest {
return DataType.FLOAT; return DataType.FLOAT;
} }
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void regressionTestMLP1() throws Exception { public void regressionTestMLP1() throws Exception {

View File

@ -61,11 +61,6 @@ public class RegressionTest071 extends BaseDL4JTest {
return DataType.FLOAT; return DataType.FLOAT;
} }
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void regressionTestMLP1() throws Exception { public void regressionTestMLP1() throws Exception {

View File

@ -60,11 +60,6 @@ public class RegressionTest080 extends BaseDL4JTest {
return DataType.FLOAT; return DataType.FLOAT;
} }
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void regressionTestMLP1() throws Exception { public void regressionTestMLP1() throws Exception {

View File

@ -57,11 +57,6 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class RegressionTest100a extends BaseDL4JTest { public class RegressionTest100a extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override @Override
public DataType getDataType(){ public DataType getDataType(){
return DataType.FLOAT; return DataType.FLOAT;

View File

@ -54,11 +54,6 @@ import static org.junit.jupiter.api.Assertions.*;
public class RegressionTest100b3 extends BaseDL4JTest { public class RegressionTest100b3 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override @Override
public DataType getDataType(){ public DataType getDataType(){
return DataType.FLOAT; return DataType.FLOAT;

View File

@ -73,11 +73,6 @@ import org.nd4j.common.resources.Resources;
public class RegressionTest100b4 extends BaseDL4JTest { public class RegressionTest100b4 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override @Override
public DataType getDataType() { public DataType getDataType() {
return DataType.FLOAT; return DataType.FLOAT;

View File

@ -60,11 +60,6 @@ public class RegressionTest100b6 extends BaseDL4JTest {
return DataType.FLOAT; return DataType.FLOAT;
} }
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void testCustomLayer() throws Exception { public void testCustomLayer() throws Exception {

View File

@ -31,11 +31,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestDistributionDeserializer extends BaseDL4JTest { public class TestDistributionDeserializer extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void testDistributionDeserializer() throws Exception { public void testDistributionDeserializer() throws Exception {
//Test current format: //Test current format:

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.cpu.nativecpu; package org.nd4j.linalg.cpu.nativecpu;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
@ -578,8 +579,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
* @return the concatenate ndarrays * @return the concatenate ndarrays
*/ */
@Override @Override
public INDArray concat(int dimension, INDArray... toConcat) { public INDArray concat(int dimension, @NonNull INDArray... toConcat) {
if (toConcat == null || toConcat.length == 0) if (toConcat.length == 0)
throw new ND4JIllegalStateException("Can't concatenate 0 arrays"); throw new ND4JIllegalStateException("Can't concatenate 0 arrays");
if (toConcat.length == 1) if (toConcat.length == 1)