diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java index da2671817..a5408f100 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java @@ -56,8 +56,10 @@ import static org.junit.Assert.*; */ public class RegressionTest050 extends BaseDL4JTest { - @Rule - public Timeout timeout = Timeout.seconds(300); + @Override + public long getTimeoutMilliseconds() { + return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections + } @Override public DataType getDataType(){ diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java index 0f6005884..10a15919d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java @@ -64,6 +64,11 @@ public class RegressionTest060 extends BaseDL4JTest { return DataType.FLOAT; } + @Override + public long getTimeoutMilliseconds() { + return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections + } + @Test public void regressionTestMLP1() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java index 76027ca59..3d34b3fd2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java @@ -64,6 +64,12 @@ public class RegressionTest071 extends BaseDL4JTest { public DataType getDataType(){ return DataType.FLOAT; } + + @Override + public long getTimeoutMilliseconds() { + return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections + } + @Test public void regressionTestMLP1() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java index b4aa72712..f963c2f59 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java @@ -64,6 +64,11 @@ public class RegressionTest080 extends BaseDL4JTest { return DataType.FLOAT; } + @Override + public long getTimeoutMilliseconds() { + return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections + } + @Test public void regressionTestMLP1() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java index dda910b0e..3214a80ef 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java @@ -56,7 +56,7 @@ public class RegressionTest100a extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections + return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections } @Override diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java index a96d4cc30..3a1bafd95 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java @@ -52,7 +52,7 @@ public class RegressionTest100b3 extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections + return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections } @Override diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java index 2be2970a7..49dd8f34a 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java @@ -71,7 +71,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections + return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections } @Override diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java index 34bca6cc2..a87a522c7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java @@ -58,7 +58,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections + return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java index 530d8f3c6..9bee792b3 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java @@ -30,6 +30,11 @@ import static org.junit.Assert.assertTrue; */ public class TestDistributionDeserializer extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections + } + @Test public void testDistributionDeserializer() throws Exception { //Test current format: diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java index 91939430f..d589c9a8a 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java @@ -46,6 +46,11 @@ public class TestDeepWalk extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Override + public long getTimeoutMilliseconds() { + return 120_000L; //Increase timeout due to intermittently slow CI machines + } + @Test(timeout = 60000L) public void testBasic() throws IOException { //Very basic test. Load graph, build tree, call fit, make sure it doesn't throw any exceptions diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml index 8a7eacada..668c728ae 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml @@ -84,12 +84,6 @@ ${project.version} test - - org.awaitility - awaitility - 4.0.2 - test - diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java index 7c4eb6783..3598bb62a 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java @@ -968,7 +968,7 @@ public class WordVectorSerializer { public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes, @NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException { // first we load syn0 - Pair pair = loadTxt(new FileInputStream(vectors)); + Pair pair = loadTxt(new FileInputStream(vectors)); //Note stream is closed in loadTxt InMemoryLookupTable lookupTable = pair.getFirst(); lookupTable.setNegative(configuration.getNegative()); if (configuration.getNegative() > 0) @@ -1607,7 +1607,7 @@ public class WordVectorSerializer { */ @Deprecated public static WordVectors loadTxtVectors(File vectorsFile) throws IOException { - FileInputStream fileInputStream = new FileInputStream(vectorsFile); + FileInputStream fileInputStream = new FileInputStream(vectorsFile); //Note stream is closed in loadTxt Pair pair = loadTxt(fileInputStream); return fromPair(pair); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java index 38b44d1ff..b8b30c6c9 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java @@ -49,8 +49,8 @@ import java.io.File; import java.util.Collection; import java.util.concurrent.Callable; -import static org.awaitility.Awaitility.await; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; @Slf4j @@ -206,12 +206,6 @@ public class Word2VecTestsSmall extends BaseDL4JTest { final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); - await() - .until(new Callable() { - @Override - public Boolean call() { - return net.params().equalsWithEps(restored.params(), 2e-3); - } - }); + assertTrue(net.params().equalsWithEps(restored.params(), 2e-3)); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java index b5c68c884..a10bb33f3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.Solver; +import org.nd4j.common.base.Preconditions; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -247,10 +248,8 @@ public abstract class BaseOutputLayer Integer.MAX_VALUE) throw new ND4JArraySizeException(); - int[] ret = new int[(int) d.size(0)]; - if (d.isRowVectorOrScalar()) - ret[0] = Nd4j.getBlasWrapper().iamax(output); - else { - for (int i = 0; i < ret.length; i++) - ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i)); - } - return ret; + Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank()); + return output.argMax(1).toIntVector(); } /** diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java index 90b84265d..6c1c5f1d4 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.ui.model.stats.StatsListener; import org.deeplearning4j.ui.model.storage.FileStatsStorage; +import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -42,8 +43,10 @@ import java.io.IOException; */ public class TestTransferStatsCollection extends BaseDL4JTest { - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + @Override + public long getTimeoutMilliseconds() { + return 90_000L; + } @Test public void test() throws IOException { @@ -62,9 +65,7 @@ public class TestTransferStatsCollection extends BaseDL4JTest { new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build()) .setFeatureExtractor(0).build(); - File dir = testDir.newFolder(); - File f = new File(dir, "dl4jTestTransferStatsCollection.bin"); - net2.setListeners(new StatsListener(new FileStatsStorage(f))); + net2.setListeners(new StatsListener(new InMemoryStatsStorage())); //Previosuly: failed on frozen layers net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));