Assorted fixes (#466)

* Timeouts and fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Increase default timeout to 90s due to slow PPC CI machines

Signed-off-by: Alex Black <blacka101@gmail.com>

* Another timeout tweak

Signed-off-by: Alex Black <blacka101@gmail.com>

* Svhn

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-05-15 15:34:08 +10:00 committed by GitHub
parent 753ce28a92
commit deb87b04f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 28 additions and 10 deletions

View File

@ -68,7 +68,7 @@ public abstract class BaseDL4JTest {
* Override this method to set the default timeout for methods in the test class * Override this method to set the default timeout for methods in the test class
*/ */
public long getTimeoutMilliseconds(){ public long getTimeoutMilliseconds(){
return 60_000; return 90_000;
} }
/** /**

View File

@ -24,17 +24,22 @@ import org.junit.rules.Timeout;
import java.io.File; import java.io.File;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeTrue;
/** /**
* @author saudet * @author saudet
*/ */
public class SvhnDataFetcherTest extends BaseDL4JTest { public class SvhnDataFetcherTest extends BaseDL4JTest {
@Rule @Override
public Timeout timeout = Timeout.seconds(600); public long getTimeoutMilliseconds() {
return 480_000L; //Shouldn't take this long but slow download or drive access on CI machines may need extra time.
}
@Test @Test
public void testSvhnDataFetcher() throws Exception { public void testSvhnDataFetcher() throws Exception {
assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access
SvhnDataFetcher fetch = new SvhnDataFetcher(); SvhnDataFetcher fetch = new SvhnDataFetcher();
File path = fetch.getDataSetPath(DataSetType.TRAIN); File path = fetch.getDataSetPath(DataSetType.TRAIN);
File path2 = fetch.getDataSetPath(DataSetType.TEST); File path2 = fetch.getDataSetPath(DataSetType.TEST);

View File

@ -183,11 +183,11 @@ public class TestCheckpointListener extends BaseDL4JTest {
CheckpointListener l = new CheckpointListener.Builder(f) CheckpointListener l = new CheckpointListener.Builder(f)
.keepLast(3) .keepLast(3)
.saveEvery(4, TimeUnit.SECONDS) .saveEvery(4900, TimeUnit.MILLISECONDS)
.build(); .build();
net.setListeners(l); net.setListeners(l);
for(int i=0; i<5; i++ ){ //10 iterations total for(int i=0; i<3; i++ ){ //10 iterations total
net.fit(iter); net.fit(iter);
Thread.sleep(5000); Thread.sleep(5000);
} }
@ -211,9 +211,10 @@ public class TestCheckpointListener extends BaseDL4JTest {
ns.add(n.getIterationCount()); ns.add(n.getIterationCount());
} }
assertEquals(3, l.availableCheckpoints().size()); assertEquals(2, l.availableCheckpoints().size());
assertEquals(ns.toString(), 3, ns.size()); assertEquals(ns.toString(), 2, ns.size());
assertTrue(ns.containsAll(Arrays.asList(4,6,8))); System.out.println(ns);
assertTrue(ns.containsAll(Arrays.asList(2,4)));
} }
@Test @Test

View File

@ -70,6 +70,11 @@ public class TestListeners extends BaseDL4JTest {
@Rule @Rule
public TemporaryFolder tempDir = new TemporaryFolder(); public TemporaryFolder tempDir = new TemporaryFolder();
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testSettingListenersUnsupervised() { public void testSettingListenersUnsupervised() {
//Pretrain layers should get copies of the listeners, in addition to the //Pretrain layers should get copies of the listeners, in addition to the

View File

@ -767,6 +767,13 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
if (!isInitCalled()) if (!isInitCalled())
init(); init();
if (solver == null) {
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this)
.build();
}
}
solver.getOptimizer().setGradientsAccumulator(accumulator); solver.getOptimizer().setGradientsAccumulator(accumulator);
} }

View File

@ -75,7 +75,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 90000L; return 180000L;
} }
@Test @Test

View File

@ -55,7 +55,7 @@ public abstract class BaseND4JTest {
* Override this method to set the default timeout for methods in the test class * Override this method to set the default timeout for methods in the test class
*/ */
public long getTimeoutMilliseconds(){ public long getTimeoutMilliseconds(){
return 60_000; return 90_000;
} }
/** /**