Assorted fixes (#466)

* Timeouts and fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Increase default timeout to 90s due to slow PPC CI machines

Signed-off-by: Alex Black <blacka101@gmail.com>

* Another timeout tweak

Signed-off-by: Alex Black <blacka101@gmail.com>

* Svhn

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-05-15 15:34:08 +10:00 committed by GitHub
parent 753ce28a92
commit deb87b04f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 28 additions and 10 deletions

View File

@ -68,7 +68,7 @@ public abstract class BaseDL4JTest {
* Override this method to set the default timeout for methods in the test class
*/
public long getTimeoutMilliseconds(){
return 60_000;
return 90_000;
}
/**

View File

@ -24,17 +24,22 @@ import org.junit.rules.Timeout;
import java.io.File;
import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeTrue;
/**
* @author saudet
*/
public class SvhnDataFetcherTest extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(600);
@Override
public long getTimeoutMilliseconds() {
return 480_000L; //Shouldn't take this long but slow download or drive access on CI machines may need extra time.
}
@Test
public void testSvhnDataFetcher() throws Exception {
assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access
SvhnDataFetcher fetch = new SvhnDataFetcher();
File path = fetch.getDataSetPath(DataSetType.TRAIN);
File path2 = fetch.getDataSetPath(DataSetType.TEST);

View File

@ -183,11 +183,11 @@ public class TestCheckpointListener extends BaseDL4JTest {
CheckpointListener l = new CheckpointListener.Builder(f)
.keepLast(3)
.saveEvery(4, TimeUnit.SECONDS)
.saveEvery(4900, TimeUnit.MILLISECONDS)
.build();
net.setListeners(l);
for(int i=0; i<5; i++ ){ //10 iterations total
for(int i=0; i<3; i++ ){ //10 iterations total
net.fit(iter);
Thread.sleep(5000);
}
@ -211,9 +211,10 @@ public class TestCheckpointListener extends BaseDL4JTest {
ns.add(n.getIterationCount());
}
assertEquals(3, l.availableCheckpoints().size());
assertEquals(ns.toString(), 3, ns.size());
assertTrue(ns.containsAll(Arrays.asList(4,6,8)));
assertEquals(2, l.availableCheckpoints().size());
assertEquals(ns.toString(), 2, ns.size());
System.out.println(ns);
assertTrue(ns.containsAll(Arrays.asList(2,4)));
}
@Test

View File

@ -70,6 +70,11 @@ public class TestListeners extends BaseDL4JTest {
@Rule
public TemporaryFolder tempDir = new TemporaryFolder();
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testSettingListenersUnsupervised() {
//Pretrain layers should get copies of the listeners, in addition to the

View File

@ -767,6 +767,13 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
if (!isInitCalled())
init();
if (solver == null) {
try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this)
.build();
}
}
solver.getOptimizer().setGradientsAccumulator(accumulator);
}

View File

@ -75,7 +75,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
return 180000L;
}
@Test

View File

@ -55,7 +55,7 @@ public abstract class BaseND4JTest {
* Override this method to set the default timeout for methods in the test class
*/
public long getTimeoutMilliseconds(){
return 60_000;
return 90_000;
}
/**