Assorted fixes (#445)

* #8890 TestTransferStatsCollection timeout fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8890 RegressionTest timeout fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Increase timeout for flaky test

Signed-off-by: Alex Black <blacka101@gmail.com>

* Remove test dependency

Signed-off-by: Alex Black <blacka101@gmail.com>

* BLAS iamax + loop -> argmax

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-05-11 15:57:00 +10:00 committed by GitHub
parent 57d16653c8
commit 91a2004d8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 52 additions and 43 deletions

View File

@ -56,8 +56,10 @@ import static org.junit.Assert.*;
*/
public class RegressionTest050 extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(300);
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override
public DataType getDataType(){

View File

@ -64,6 +64,11 @@ public class RegressionTest060 extends BaseDL4JTest {
return DataType.FLOAT;
}
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void regressionTestMLP1() throws Exception {

View File

@ -64,6 +64,12 @@ public class RegressionTest071 extends BaseDL4JTest {
public DataType getDataType(){
return DataType.FLOAT;
}
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void regressionTestMLP1() throws Exception {

View File

@ -64,6 +64,11 @@ public class RegressionTest080 extends BaseDL4JTest {
return DataType.FLOAT;
}
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void regressionTestMLP1() throws Exception {

View File

@ -56,7 +56,7 @@ public class RegressionTest100a extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override

View File

@ -52,7 +52,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override

View File

@ -71,7 +71,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override

View File

@ -58,7 +58,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test

View File

@ -30,6 +30,11 @@ import static org.junit.Assert.assertTrue;
*/
public class TestDistributionDeserializer extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void testDistributionDeserializer() throws Exception {
//Test current format:

View File

@ -46,6 +46,11 @@ public class TestDeepWalk extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Override
public long getTimeoutMilliseconds() {
return 120_000L; //Increase timeout due to intermittently slow CI machines
}
@Test(timeout = 60000L)
public void testBasic() throws IOException {
//Very basic test. Load graph, build tree, call fit, make sure it doesn't throw any exceptions

View File

@ -84,12 +84,6 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<version>4.0.2</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -968,7 +968,7 @@ public class WordVectorSerializer {
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes,
@NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
// first we load syn0
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors));
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors)); //Note stream is closed in loadTxt
InMemoryLookupTable lookupTable = pair.getFirst();
lookupTable.setNegative(configuration.getNegative());
if (configuration.getNegative() > 0)
@ -1607,7 +1607,7 @@ public class WordVectorSerializer {
*/
@Deprecated
public static WordVectors loadTxtVectors(File vectorsFile) throws IOException {
FileInputStream fileInputStream = new FileInputStream(vectorsFile);
FileInputStream fileInputStream = new FileInputStream(vectorsFile); //Note stream is closed in loadTxt
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(fileInputStream);
return fromPair(pair);
}

View File

@ -49,8 +49,8 @@ import java.io.File;
import java.util.Collection;
import java.util.concurrent.Callable;
import static org.awaitility.Awaitility.await;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@Slf4j
@ -206,12 +206,6 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
await()
.until(new Callable<Boolean>() {
@Override
public Boolean call() {
return net.params().equalsWithEps(restored.params(), 2e-3);
}
});
assertTrue(net.params().equalsWithEps(restored.params(), 2e-3));
}
}

View File

@ -25,6 +25,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -247,10 +248,8 @@ public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.c
@Override
public int[] predict(INDArray input) {
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
int[] ret = new int[input.rows()];
for (int i = 0; i < ret.length; i++)
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
return ret;
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
return output.argMax(1).toIntVector();
}
/**

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
@ -251,10 +252,8 @@ public class LossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.LossL
@Override
public int[] predict(INDArray input) {
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
int[] ret = new int[input.rows()];
for (int i = 0; i < ret.length; i++)
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
return ret;
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
return output.argMax(1).toIntVector();
}
/**

View File

@ -2220,14 +2220,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
if (d.size(0) > Integer.MAX_VALUE)
throw new ND4JArraySizeException();
int[] ret = new int[(int) d.size(0)];
if (d.isRowVectorOrScalar())
ret[0] = Nd4j.getBlasWrapper().iamax(output);
else {
for (int i = 0; i < ret.length; i++)
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
}
return ret;
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
return output.argMax(1).toIntVector();
}
/**

View File

@ -26,6 +26,7 @@ import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
@ -42,8 +43,10 @@ import java.io.IOException;
*/
public class TestTransferStatsCollection extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Override
public long getTimeoutMilliseconds() {
return 90_000L;
}
@Test
public void test() throws IOException {
@ -62,9 +65,7 @@ public class TestTransferStatsCollection extends BaseDL4JTest {
new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
.setFeatureExtractor(0).build();
File dir = testDir.newFolder();
File f = new File(dir, "dl4jTestTransferStatsCollection.bin");
net2.setListeners(new StatsListener(new FileStatsStorage(f)));
net2.setListeners(new StatsListener(new InMemoryStatsStorage()));
//Previosuly: failed on frozen layers
net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));