parent
a7f75fe6db
commit
9660ab026d
|
@ -16,4 +16,6 @@ dependencies {
|
||||||
implementation 'org.apache.commons:commons-math3'
|
implementation 'org.apache.commons:commons-math3'
|
||||||
implementation 'org.apache.commons:commons-lang3'
|
implementation 'org.apache.commons:commons-lang3'
|
||||||
implementation 'org.apache.commons:commons-compress'
|
implementation 'org.apache.commons:commons-compress'
|
||||||
}
|
|
||||||
|
testRuntimeOnly 'net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT'
|
||||||
|
}
|
||||||
|
|
|
@ -27,11 +27,6 @@ import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
public class TestDataSets extends BaseDL4JTest {
|
public class TestDataSets extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTinyImageNetExists() throws Exception {
|
public void testTinyImageNetExists() throws Exception {
|
||||||
//Simple sanity check on extracting
|
//Simple sanity check on extracting
|
||||||
|
|
|
@ -35,11 +35,6 @@ import static org.junit.jupiter.api.Assumptions.assumeTrue;
|
||||||
*/
|
*/
|
||||||
public class SvhnDataFetcherTest extends BaseDL4JTest {
|
public class SvhnDataFetcherTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 480_000_000L; //Shouldn't take this long but slow download or drive access on CI machines may need extra time.
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSvhnDataFetcher() throws Exception {
|
public void testSvhnDataFetcher() throws Exception {
|
||||||
assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access
|
assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access
|
||||||
|
|
|
@ -59,11 +59,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class DataSetIteratorTest extends BaseDL4JTest {
|
public class DataSetIteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 360000; //Should run quickly; increased to large timeout due to occasonal slow CI downloads
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBatchSizeOfOneIris() throws Exception {
|
public void testBatchSizeOfOneIris() throws Exception {
|
||||||
//Test for (a) iterators returning correct number of examples, and
|
//Test for (a) iterators returning correct number of examples, and
|
||||||
|
|
|
@ -51,10 +51,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
////@Ignore
|
////@Ignore
|
||||||
public class AttentionLayerTest extends BaseDL4JTest {
|
public class AttentionLayerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSelfAttentionLayer() {
|
public void testSelfAttentionLayer() {
|
||||||
|
|
|
@ -62,11 +62,6 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGradient2dSimple() {
|
public void testGradient2dSimple() {
|
||||||
DataNormalization scaler = new NormalizerMinMaxScaler();
|
DataNormalization scaler = new NormalizerMinMaxScaler();
|
||||||
|
|
|
@ -62,11 +62,6 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCnn1DWithLocallyConnected1D() {
|
public void testCnn1DWithLocallyConnected1D() {
|
||||||
Nd4j.getRandom().setSeed(1337);
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
|
|
@ -59,11 +59,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCnn3DPlain() {
|
public void testCnn3DPlain() {
|
||||||
Nd4j.getRandom().setSeed(1337);
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
|
|
@ -73,11 +73,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
return CNN2DFormat.values();
|
return CNN2DFormat.values();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 999990000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGradientCNNMLN() {
|
public void testGradientCNNMLN() {
|
||||||
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...
|
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...
|
||||||
|
|
|
@ -49,11 +49,6 @@ import java.util.Random;
|
||||||
////@Ignore
|
////@Ignore
|
||||||
public class CapsnetGradientCheckTest extends BaseDL4JTest {
|
public class CapsnetGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCapsNet() {
|
public void testCapsNet() {
|
||||||
|
|
||||||
|
|
|
@ -59,11 +59,6 @@ public class DropoutGradientCheck extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDropoutGradient() {
|
public void testDropoutGradient() {
|
||||||
int minibatch = 3;
|
int minibatch = 3;
|
||||||
|
|
|
@ -55,11 +55,6 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
|
||||||
private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
|
private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
|
||||||
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
|
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRNNGlobalPoolingBasicMultiLayer() {
|
public void testRNNGlobalPoolingBasicMultiLayer() {
|
||||||
//Basic test of global pooling w/ LSTM
|
//Basic test of global pooling w/ LSTM
|
||||||
|
|
|
@ -70,11 +70,6 @@ public class GradientCheckTests extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMinibatchApplication() {
|
public void testMinibatchApplication() {
|
||||||
IrisDataSetIterator iter = new IrisDataSetIterator(30, 150);
|
IrisDataSetIterator iter = new IrisDataSetIterator(30, 150);
|
||||||
|
|
|
@ -71,11 +71,6 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 999999999L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicIris() {
|
public void testBasicIris() {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -59,11 +59,6 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class GradientCheckSimpleScenario {
|
private static class GradientCheckSimpleScenario {
|
||||||
private final ILossFunction lf;
|
private final ILossFunction lf;
|
||||||
private final Activation act;
|
private final Activation act;
|
||||||
|
|
|
@ -54,12 +54,6 @@ public class LRNGradientCheckTests extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGradientLRNSimple() {
|
public void testGradientLRNSimple() {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -55,11 +55,6 @@ public class LSTMGradientCheckTests extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLSTMBasicMultiLayer() {
|
public void testLSTMBasicMultiLayer() {
|
||||||
//Basic test of GravesLSTM layer
|
//Basic test of GravesLSTM layer
|
||||||
|
|
|
@ -73,11 +73,6 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
|
||||||
private static final double DEFAULT_MAX_REL_ERROR = 1e-5;
|
private static final double DEFAULT_MAX_REL_ERROR = 1e-5;
|
||||||
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
|
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void lossFunctionGradientCheck() {
|
public void lossFunctionGradientCheck() {
|
||||||
ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), new LossBinaryXENT(),
|
ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), new LossBinaryXENT(),
|
||||||
|
|
|
@ -52,11 +52,6 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGradientNoBiasDenseOutput() {
|
public void testGradientNoBiasDenseOutput() {
|
||||||
|
|
||||||
|
|
|
@ -52,11 +52,6 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRnnLossLayer() {
|
public void testRnnLossLayer() {
|
||||||
Nd4j.getRandom().setSeed(12345L);
|
Nd4j.getRandom().setSeed(12345L);
|
||||||
|
|
|
@ -55,11 +55,6 @@ public class RnnGradientChecks extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
////@Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
|
////@Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
|
||||||
public void testBidirectionalWrapper() {
|
public void testBidirectionalWrapper() {
|
||||||
|
|
|
@ -56,11 +56,6 @@ public class UtilLayerGradientChecks extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMaskLayer() {
|
public void testMaskLayer() {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -57,11 +57,6 @@ public class VaeGradientCheckTests extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVaeAsMLP() {
|
public void testVaeAsMLP() {
|
||||||
//Post pre-training: a VAE can be used as a MLP, by taking the mean value from p(z|x) as the output
|
//Post pre-training: a VAE can be used as a MLP, by taking the mean value from p(z|x) as the output
|
||||||
|
|
|
@ -72,11 +72,6 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
||||||
@TempDir
|
@TempDir
|
||||||
public File testDir;
|
public File testDir;
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testYoloOutputLayer() {
|
public void testYoloOutputLayer() {
|
||||||
int depthIn = 2;
|
int depthIn = 2;
|
||||||
|
|
|
@ -186,11 +186,6 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
|
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
|
||||||
));
|
));
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 9999999L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@AfterAll
|
@AfterAll
|
||||||
public static void after() {
|
public static void after() {
|
||||||
ImmutableSet<ClassPath.ClassInfo> info;
|
ImmutableSet<ClassPath.ClassInfo> info;
|
||||||
|
|
|
@ -93,11 +93,6 @@ public class BatchNormalizationTest extends BaseDL4JTest {
|
||||||
public void doBefore() {
|
public void doBefore() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDnnForwardPass() {
|
public void testDnnForwardPass() {
|
||||||
int nOut = 10;
|
int nOut = 10;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.multilayer;
|
package org.deeplearning4j.nn.multilayer;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
|
@ -1424,6 +1425,7 @@ public class MultiLayerTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public static class CheckModelsListener extends BaseTrainingListener {
|
public static class CheckModelsListener extends BaseTrainingListener {
|
||||||
|
|
||||||
private Set<Class<?>> modelClasses = new HashSet<>();
|
private Set<Class<?>> modelClasses = new HashSet<>();
|
||||||
|
|
|
@ -39,11 +39,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
|
public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 1200000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This test ensures, that memory amount assigned to buffer is enough for any number of updates
|
* This test ensures, that memory amount assigned to buffer is enough for any number of updates
|
||||||
* @throws Exception
|
* @throws Exception
|
||||||
|
|
|
@ -47,11 +47,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestCheckpointListener extends BaseDL4JTest {
|
public class TestCheckpointListener extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@TempDir
|
@TempDir
|
||||||
public File tempDir;
|
public File tempDir;
|
||||||
|
|
||||||
|
|
|
@ -67,11 +67,6 @@ public class TestListeners extends BaseDL4JTest {
|
||||||
@TempDir
|
@TempDir
|
||||||
public File tempDir;
|
public File tempDir;
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSettingListenersUnsupervised() {
|
public void testSettingListenersUnsupervised() {
|
||||||
//Pretrain layers should get copies of the listeners, in addition to the
|
//Pretrain layers should get copies of the listeners, in addition to the
|
||||||
|
|
|
@ -60,11 +60,6 @@ public class RegressionTest060 extends BaseDL4JTest {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void regressionTestMLP1() throws Exception {
|
public void regressionTestMLP1() throws Exception {
|
||||||
|
|
||||||
|
|
|
@ -61,11 +61,6 @@ public class RegressionTest071 extends BaseDL4JTest {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void regressionTestMLP1() throws Exception {
|
public void regressionTestMLP1() throws Exception {
|
||||||
|
|
||||||
|
|
|
@ -60,11 +60,6 @@ public class RegressionTest080 extends BaseDL4JTest {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void regressionTestMLP1() throws Exception {
|
public void regressionTestMLP1() throws Exception {
|
||||||
|
|
||||||
|
|
|
@ -57,11 +57,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class RegressionTest100a extends BaseDL4JTest {
|
public class RegressionTest100a extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType(){
|
public DataType getDataType(){
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
|
|
|
@ -54,11 +54,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class RegressionTest100b3 extends BaseDL4JTest {
|
public class RegressionTest100b3 extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType(){
|
public DataType getDataType(){
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
|
|
|
@ -73,11 +73,6 @@ import org.nd4j.common.resources.Resources;
|
||||||
|
|
||||||
public class RegressionTest100b4 extends BaseDL4JTest {
|
public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType() {
|
public DataType getDataType() {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
|
|
|
@ -60,11 +60,6 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCustomLayer() throws Exception {
|
public void testCustomLayer() throws Exception {
|
||||||
|
|
||||||
|
|
|
@ -31,11 +31,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestDistributionDeserializer extends BaseDL4JTest {
|
public class TestDistributionDeserializer extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDistributionDeserializer() throws Exception {
|
public void testDistributionDeserializer() throws Exception {
|
||||||
//Test current format:
|
//Test current format:
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.cpu.nativecpu;
|
package org.nd4j.linalg.cpu.nativecpu;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
@ -578,8 +579,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
* @return the concatenate ndarrays
|
* @return the concatenate ndarrays
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public INDArray concat(int dimension, INDArray... toConcat) {
|
public INDArray concat(int dimension, @NonNull INDArray... toConcat) {
|
||||||
if (toConcat == null || toConcat.length == 0)
|
if (toConcat.length == 0)
|
||||||
throw new ND4JIllegalStateException("Can't concatenate 0 arrays");
|
throw new ND4JIllegalStateException("Can't concatenate 0 arrays");
|
||||||
|
|
||||||
if (toConcat.length == 1)
|
if (toConcat.length == 1)
|
||||||
|
|
Loading…
Reference in New Issue