diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java
index da2671817..a5408f100 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java
@@ -56,8 +56,10 @@ import static org.junit.Assert.*;
*/
public class RegressionTest050 extends BaseDL4JTest {
- @Rule
- public Timeout timeout = Timeout.seconds(300);
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+ }
@Override
public DataType getDataType(){
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java
index 0f6005884..10a15919d 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java
@@ -64,6 +64,11 @@ public class RegressionTest060 extends BaseDL4JTest {
return DataType.FLOAT;
}
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+ }
+
@Test
public void regressionTestMLP1() throws Exception {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java
index 76027ca59..3d34b3fd2 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java
@@ -64,6 +64,12 @@ public class RegressionTest071 extends BaseDL4JTest {
public DataType getDataType(){
return DataType.FLOAT;
}
+
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+ }
+
@Test
public void regressionTestMLP1() throws Exception {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java
index b4aa72712..f963c2f59 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java
@@ -64,6 +64,11 @@ public class RegressionTest080 extends BaseDL4JTest {
return DataType.FLOAT;
}
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+ }
+
@Test
public void regressionTestMLP1() throws Exception {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java
index dda910b0e..3214a80ef 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java
@@ -56,7 +56,7 @@ public class RegressionTest100a extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
- return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+ return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java
index a96d4cc30..3a1bafd95 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java
@@ -52,7 +52,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
- return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+ return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java
index 2be2970a7..49dd8f34a 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java
@@ -71,7 +71,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
- return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+ return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java
index 34bca6cc2..a87a522c7 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java
@@ -58,7 +58,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
- return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+ return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java
index 530d8f3c6..9bee792b3 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java
@@ -30,6 +30,11 @@ import static org.junit.Assert.assertTrue;
*/
public class TestDistributionDeserializer extends BaseDL4JTest {
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
+ }
+
@Test
public void testDistributionDeserializer() throws Exception {
//Test current format:
diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java
index 91939430f..d589c9a8a 100644
--- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java
+++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java
@@ -46,6 +46,11 @@ public class TestDeepWalk extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 120_000L; //Increase timeout due to intermittently slow CI machines
+ }
+
@Test(timeout = 60000L)
public void testBasic() throws IOException {
//Very basic test. Load graph, build tree, call fit, make sure it doesn't throw any exceptions
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml
index 8a7eacada..668c728ae 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml
@@ -84,12 +84,6 @@
${project.version}
test
-
- org.awaitility
- awaitility
- 4.0.2
- test
-
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
index 7c4eb6783..3598bb62a 100755
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
@@ -968,7 +968,7 @@ public class WordVectorSerializer {
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes,
@NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
// first we load syn0
- Pair pair = loadTxt(new FileInputStream(vectors));
+ Pair pair = loadTxt(new FileInputStream(vectors)); //Note stream is closed in loadTxt
InMemoryLookupTable lookupTable = pair.getFirst();
lookupTable.setNegative(configuration.getNegative());
if (configuration.getNegative() > 0)
@@ -1607,7 +1607,7 @@ public class WordVectorSerializer {
*/
@Deprecated
public static WordVectors loadTxtVectors(File vectorsFile) throws IOException {
- FileInputStream fileInputStream = new FileInputStream(vectorsFile);
+ FileInputStream fileInputStream = new FileInputStream(vectorsFile); //Note stream is closed in loadTxt
Pair pair = loadTxt(fileInputStream);
return fromPair(pair);
}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java
index 38b44d1ff..b8b30c6c9 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java
@@ -49,8 +49,8 @@ import java.io.File;
import java.util.Collection;
import java.util.concurrent.Callable;
-import static org.awaitility.Awaitility.await;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
@Slf4j
@@ -206,12 +206,6 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
- await()
- .until(new Callable() {
- @Override
- public Boolean call() {
- return net.params().equalsWithEps(restored.params(), 2e-3);
- }
- });
+ assertTrue(net.params().equalsWithEps(restored.params(), 2e-3));
}
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java
index b5c68c884..a10bb33f3 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java
@@ -25,6 +25,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
+import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -247,10 +248,8 @@ public abstract class BaseOutputLayer Integer.MAX_VALUE)
throw new ND4JArraySizeException();
- int[] ret = new int[(int) d.size(0)];
- if (d.isRowVectorOrScalar())
- ret[0] = Nd4j.getBlasWrapper().iamax(output);
- else {
- for (int i = 0; i < ret.length; i++)
- ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
- }
- return ret;
+ Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
+ return output.argMax(1).toIntVector();
}
/**
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java
index 90b84265d..6c1c5f1d4 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java
@@ -26,6 +26,7 @@ import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
+import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
@@ -42,8 +43,10 @@ import java.io.IOException;
*/
public class TestTransferStatsCollection extends BaseDL4JTest {
- @Rule
- public TemporaryFolder testDir = new TemporaryFolder();
+ @Override
+ public long getTimeoutMilliseconds() {
+ return 90_000L;
+ }
@Test
public void test() throws IOException {
@@ -62,9 +65,7 @@ public class TestTransferStatsCollection extends BaseDL4JTest {
new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
.setFeatureExtractor(0).build();
- File dir = testDir.newFolder();
- File f = new File(dir, "dl4jTestTransferStatsCollection.bin");
- net2.setListeners(new StatsListener(new FileStatsStorage(f)));
+ net2.setListeners(new StatsListener(new InMemoryStatsStorage()));
//Previosuly: failed on frozen layers
net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));