Assorted fixes (#445)

* #8890 TestTransferStatsCollection timeout fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8890 RegressionTest timeout fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Increase timeout for flaky test

Signed-off-by: Alex Black <blacka101@gmail.com>

* Remove test dependency

Signed-off-by: Alex Black <blacka101@gmail.com>

* BLAS iamax + loop -> argmax

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-05-11 15:57:00 +10:00 committed by GitHub
parent 57d16653c8
commit 91a2004d8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 52 additions and 43 deletions

View File

@ -56,8 +56,10 @@ import static org.junit.Assert.*;
*/ */
public class RegressionTest050 extends BaseDL4JTest { public class RegressionTest050 extends BaseDL4JTest {
@Rule @Override
public Timeout timeout = Timeout.seconds(300); public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override @Override
public DataType getDataType(){ public DataType getDataType(){

View File

@ -64,6 +64,11 @@ public class RegressionTest060 extends BaseDL4JTest {
return DataType.FLOAT; return DataType.FLOAT;
} }
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void regressionTestMLP1() throws Exception { public void regressionTestMLP1() throws Exception {

View File

@ -64,6 +64,12 @@ public class RegressionTest071 extends BaseDL4JTest {
public DataType getDataType(){ public DataType getDataType(){
return DataType.FLOAT; return DataType.FLOAT;
} }
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void regressionTestMLP1() throws Exception { public void regressionTestMLP1() throws Exception {

View File

@ -64,6 +64,11 @@ public class RegressionTest080 extends BaseDL4JTest {
return DataType.FLOAT; return DataType.FLOAT;
} }
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void regressionTestMLP1() throws Exception { public void regressionTestMLP1() throws Exception {

View File

@ -56,7 +56,7 @@ public class RegressionTest100a extends BaseDL4JTest {
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
} }
@Override @Override

View File

@ -52,7 +52,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
} }
@Override @Override

View File

@ -71,7 +71,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
} }
@Override @Override

View File

@ -58,7 +58,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
@Override @Override
public long getTimeoutMilliseconds() { public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
} }
@Test @Test

View File

@ -30,6 +30,11 @@ import static org.junit.Assert.assertTrue;
*/ */
public class TestDistributionDeserializer extends BaseDL4JTest { public class TestDistributionDeserializer extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void testDistributionDeserializer() throws Exception { public void testDistributionDeserializer() throws Exception {
//Test current format: //Test current format:

View File

@ -46,6 +46,11 @@ public class TestDeepWalk extends BaseDL4JTest {
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();
@Override
public long getTimeoutMilliseconds() {
return 120_000L; //Increase timeout due to intermittently slow CI machines
}
@Test(timeout = 60000L) @Test(timeout = 60000L)
public void testBasic() throws IOException { public void testBasic() throws IOException {
//Very basic test. Load graph, build tree, call fit, make sure it doesn't throw any exceptions //Very basic test. Load graph, build tree, call fit, make sure it doesn't throw any exceptions

View File

@ -84,12 +84,6 @@
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<version>4.0.2</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -968,7 +968,7 @@ public class WordVectorSerializer {
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes, public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes,
@NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException { @NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
// first we load syn0 // first we load syn0
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors)); Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors)); //Note stream is closed in loadTxt
InMemoryLookupTable lookupTable = pair.getFirst(); InMemoryLookupTable lookupTable = pair.getFirst();
lookupTable.setNegative(configuration.getNegative()); lookupTable.setNegative(configuration.getNegative());
if (configuration.getNegative() > 0) if (configuration.getNegative() > 0)
@ -1607,7 +1607,7 @@ public class WordVectorSerializer {
*/ */
@Deprecated @Deprecated
public static WordVectors loadTxtVectors(File vectorsFile) throws IOException { public static WordVectors loadTxtVectors(File vectorsFile) throws IOException {
FileInputStream fileInputStream = new FileInputStream(vectorsFile); FileInputStream fileInputStream = new FileInputStream(vectorsFile); //Note stream is closed in loadTxt
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(fileInputStream); Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(fileInputStream);
return fromPair(pair); return fromPair(pair);
} }

View File

@ -49,8 +49,8 @@ import java.io.File;
import java.util.Collection; import java.util.Collection;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import static org.awaitility.Awaitility.await;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@Slf4j @Slf4j
@ -206,12 +206,6 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
await() assertTrue(net.params().equalsWithEps(restored.params(), 2e-3));
.until(new Callable<Boolean>() {
@Override
public Boolean call() {
return net.params().equalsWithEps(restored.params(), 2e-3);
}
});
} }
} }

View File

@ -25,6 +25,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver; import org.deeplearning4j.optimize.Solver;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -247,10 +248,8 @@ public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.c
@Override @Override
public int[] predict(INDArray input) { public int[] predict(INDArray input) {
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable()); INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
int[] ret = new int[input.rows()]; Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
for (int i = 0; i < ret.length; i++) return output.argMax(1).toIntVector();
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
return ret;
} }
/** /**

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.Solver; import org.deeplearning4j.optimize.Solver;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.DataSet;
@ -251,10 +252,8 @@ public class LossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.LossL
@Override @Override
public int[] predict(INDArray input) { public int[] predict(INDArray input) {
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable()); INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
int[] ret = new int[input.rows()]; Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
for (int i = 0; i < ret.length; i++) return output.argMax(1).toIntVector();
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
return ret;
} }
/** /**

View File

@ -2220,14 +2220,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
if (d.size(0) > Integer.MAX_VALUE) if (d.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException(); throw new ND4JArraySizeException();
int[] ret = new int[(int) d.size(0)]; Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
if (d.isRowVectorOrScalar()) return output.argMax(1).toIntVector();
ret[0] = Nd4j.getBlasWrapper().iamax(output);
else {
for (int i = 0; i < ret.length; i++)
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
}
return ret;
} }
/** /**

View File

@ -26,6 +26,7 @@ import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.ui.model.stats.StatsListener; import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.FileStatsStorage; import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
@ -42,8 +43,10 @@ import java.io.IOException;
*/ */
public class TestTransferStatsCollection extends BaseDL4JTest { public class TestTransferStatsCollection extends BaseDL4JTest {
@Rule @Override
public TemporaryFolder testDir = new TemporaryFolder(); public long getTimeoutMilliseconds() {
return 90_000L;
}
@Test @Test
public void test() throws IOException { public void test() throws IOException {
@ -62,9 +65,7 @@ public class TestTransferStatsCollection extends BaseDL4JTest {
new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build()) new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
.setFeatureExtractor(0).build(); .setFeatureExtractor(0).build();
File dir = testDir.newFolder(); net2.setListeners(new StatsListener(new InMemoryStatsStorage()));
File f = new File(dir, "dl4jTestTransferStatsCollection.bin");
net2.setListeners(new StatsListener(new FileStatsStorage(f)));
//Previosuly: failed on frozen layers //Previosuly: failed on frozen layers
net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10))); net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));