diff --git a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java index 46daaa5f5..b74df2d2c 100644 --- a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java +++ b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java @@ -68,7 +68,7 @@ public abstract class BaseDL4JTest { * Override this method to set the default timeout for methods in the test class */ public long getTimeoutMilliseconds(){ - return 60_000; + return 90_000; } /** diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java index 1815dff73..58587615d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java @@ -24,17 +24,22 @@ import org.junit.rules.Timeout; import java.io.File; import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; /** * @author saudet */ public class SvhnDataFetcherTest extends BaseDL4JTest { - @Rule - public Timeout timeout = Timeout.seconds(600); + @Override + public long getTimeoutMilliseconds() { + return 480_000L; //Shouldn't take this long but slow download or drive access on CI machines may need extra time. + } @Test public void testSvhnDataFetcher() throws Exception { + assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access + SvhnDataFetcher fetch = new SvhnDataFetcher(); File path = fetch.getDataSetPath(DataSetType.TRAIN); File path2 = fetch.getDataSetPath(DataSetType.TEST); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java index 5c5f9e385..131930623 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java @@ -183,11 +183,11 @@ public class TestCheckpointListener extends BaseDL4JTest { CheckpointListener l = new CheckpointListener.Builder(f) .keepLast(3) - .saveEvery(4, TimeUnit.SECONDS) + .saveEvery(4900, TimeUnit.MILLISECONDS) .build(); net.setListeners(l); - for(int i=0; i<5; i++ ){ //10 iterations total + for(int i=0; i<3; i++ ){ //10 iterations total net.fit(iter); Thread.sleep(5000); } @@ -211,9 +211,10 @@ public class TestCheckpointListener extends BaseDL4JTest { ns.add(n.getIterationCount()); } - assertEquals(3, l.availableCheckpoints().size()); - assertEquals(ns.toString(), 3, ns.size()); - assertTrue(ns.containsAll(Arrays.asList(4,6,8))); + assertEquals(2, l.availableCheckpoints().size()); + assertEquals(ns.toString(), 2, ns.size()); + System.out.println(ns); + assertTrue(ns.containsAll(Arrays.asList(2,4))); } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java index 8cd72e770..cac30a7e4 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java @@ -70,6 +70,11 @@ public class TestListeners extends BaseDL4JTest { @Rule public TemporaryFolder tempDir = new TemporaryFolder(); + @Override + public long getTimeoutMilliseconds() { + return 90000L; + } + @Test public void testSettingListenersUnsupervised() { //Pretrain layers should get copies of the listeners, in addition to the diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index 54acb31d7..2091babb0 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -767,6 +767,13 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura if (!isInitCalled()) init(); + if (solver == null) { + try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) + .build(); + } + } + solver.getOptimizer().setGradientsAccumulator(accumulator); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java index ab034604e..68a012b72 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java @@ -75,7 +75,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest { @Override public long getTimeoutMilliseconds() { - return 90000L; + return 180000L; } @Test diff --git a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/BaseND4JTest.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/BaseND4JTest.java index 54bad9876..eceec6216 100644 --- a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/BaseND4JTest.java +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/BaseND4JTest.java @@ -55,7 +55,7 @@ public abstract class BaseND4JTest { * Override this method to set the default timeout for methods in the test class */ public long getTimeoutMilliseconds(){ - return 60_000; + return 90_000; } /**