Assorted fixes (#445)
* #8890 TestTransferStatsCollection timeout fix Signed-off-by: Alex Black <blacka101@gmail.com> * #8890 RegressionTest timeout fix Signed-off-by: Alex Black <blacka101@gmail.com> * Increase timeout for flaky test Signed-off-by: Alex Black <blacka101@gmail.com> * Remove test dependency Signed-off-by: Alex Black <blacka101@gmail.com> * BLAS iamax + loop -> argmax Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
57d16653c8
commit
91a2004d8f
|
@ -56,8 +56,10 @@ import static org.junit.Assert.*;
|
||||||
*/
|
*/
|
||||||
public class RegressionTest050 extends BaseDL4JTest {
|
public class RegressionTest050 extends BaseDL4JTest {
|
||||||
|
|
||||||
@Rule
|
@Override
|
||||||
public Timeout timeout = Timeout.seconds(300);
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType(){
|
public DataType getDataType(){
|
||||||
|
|
|
@ -64,6 +64,11 @@ public class RegressionTest060 extends BaseDL4JTest {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void regressionTestMLP1() throws Exception {
|
public void regressionTestMLP1() throws Exception {
|
||||||
|
|
||||||
|
|
|
@ -64,6 +64,12 @@ public class RegressionTest071 extends BaseDL4JTest {
|
||||||
public DataType getDataType(){
|
public DataType getDataType(){
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void regressionTestMLP1() throws Exception {
|
public void regressionTestMLP1() throws Exception {
|
||||||
|
|
||||||
|
|
|
@ -64,6 +64,11 @@ public class RegressionTest080 extends BaseDL4JTest {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void regressionTestMLP1() throws Exception {
|
public void regressionTestMLP1() throws Exception {
|
||||||
|
|
||||||
|
|
|
@ -56,7 +56,7 @@ public class RegressionTest100a extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -52,7 +52,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -71,7 +71,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -58,7 +58,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -30,6 +30,11 @@ import static org.junit.Assert.assertTrue;
|
||||||
*/
|
*/
|
||||||
public class TestDistributionDeserializer extends BaseDL4JTest {
|
public class TestDistributionDeserializer extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDistributionDeserializer() throws Exception {
|
public void testDistributionDeserializer() throws Exception {
|
||||||
//Test current format:
|
//Test current format:
|
||||||
|
|
|
@ -46,6 +46,11 @@ public class TestDeepWalk extends BaseDL4JTest {
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 120_000L; //Increase timeout due to intermittently slow CI machines
|
||||||
|
}
|
||||||
|
|
||||||
@Test(timeout = 60000L)
|
@Test(timeout = 60000L)
|
||||||
public void testBasic() throws IOException {
|
public void testBasic() throws IOException {
|
||||||
//Very basic test. Load graph, build tree, call fit, make sure it doesn't throw any exceptions
|
//Very basic test. Load graph, build tree, call fit, make sure it doesn't throw any exceptions
|
||||||
|
|
|
@ -84,12 +84,6 @@
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>org.awaitility</groupId>
|
|
||||||
<artifactId>awaitility</artifactId>
|
|
||||||
<version>4.0.2</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@ -968,7 +968,7 @@ public class WordVectorSerializer {
|
||||||
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes,
|
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes,
|
||||||
@NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
|
@NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
|
||||||
// first we load syn0
|
// first we load syn0
|
||||||
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors));
|
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors)); //Note stream is closed in loadTxt
|
||||||
InMemoryLookupTable lookupTable = pair.getFirst();
|
InMemoryLookupTable lookupTable = pair.getFirst();
|
||||||
lookupTable.setNegative(configuration.getNegative());
|
lookupTable.setNegative(configuration.getNegative());
|
||||||
if (configuration.getNegative() > 0)
|
if (configuration.getNegative() > 0)
|
||||||
|
@ -1607,7 +1607,7 @@ public class WordVectorSerializer {
|
||||||
*/
|
*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public static WordVectors loadTxtVectors(File vectorsFile) throws IOException {
|
public static WordVectors loadTxtVectors(File vectorsFile) throws IOException {
|
||||||
FileInputStream fileInputStream = new FileInputStream(vectorsFile);
|
FileInputStream fileInputStream = new FileInputStream(vectorsFile); //Note stream is closed in loadTxt
|
||||||
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(fileInputStream);
|
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(fileInputStream);
|
||||||
return fromPair(pair);
|
return fromPair(pair);
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,8 +49,8 @@ import java.io.File;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.concurrent.Callable;
|
import java.util.concurrent.Callable;
|
||||||
|
|
||||||
import static org.awaitility.Awaitility.await;
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@ -206,12 +206,6 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
|
||||||
final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
|
final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
|
||||||
|
|
||||||
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
|
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
|
||||||
await()
|
assertTrue(net.params().equalsWithEps(restored.params(), 2e-3));
|
||||||
.until(new Callable<Boolean>() {
|
|
||||||
@Override
|
|
||||||
public Boolean call() {
|
|
||||||
return net.params().equalsWithEps(restored.params(), 2e-3);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.deeplearning4j.optimize.Solver;
|
import org.deeplearning4j.optimize.Solver;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.evaluation.classification.Evaluation;
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -247,10 +248,8 @@ public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.c
|
||||||
@Override
|
@Override
|
||||||
public int[] predict(INDArray input) {
|
public int[] predict(INDArray input) {
|
||||||
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
|
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
|
||||||
int[] ret = new int[input.rows()];
|
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
|
||||||
for (int i = 0; i < ret.length; i++)
|
return output.argMax(1).toIntVector();
|
||||||
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
|
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.optimize.Solver;
|
import org.deeplearning4j.optimize.Solver;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
@ -251,10 +252,8 @@ public class LossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.LossL
|
||||||
@Override
|
@Override
|
||||||
public int[] predict(INDArray input) {
|
public int[] predict(INDArray input) {
|
||||||
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
|
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
|
||||||
int[] ret = new int[input.rows()];
|
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
|
||||||
for (int i = 0; i < ret.length; i++)
|
return output.argMax(1).toIntVector();
|
||||||
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
|
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -2220,14 +2220,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
if (d.size(0) > Integer.MAX_VALUE)
|
if (d.size(0) > Integer.MAX_VALUE)
|
||||||
throw new ND4JArraySizeException();
|
throw new ND4JArraySizeException();
|
||||||
|
|
||||||
int[] ret = new int[(int) d.size(0)];
|
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
|
||||||
if (d.isRowVectorOrScalar())
|
return output.argMax(1).toIntVector();
|
||||||
ret[0] = Nd4j.getBlasWrapper().iamax(output);
|
|
||||||
else {
|
|
||||||
for (int i = 0; i < ret.length; i++)
|
|
||||||
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
|
|
||||||
}
|
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.deeplearning4j.nn.transferlearning.TransferLearning;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.ui.model.stats.StatsListener;
|
import org.deeplearning4j.ui.model.stats.StatsListener;
|
||||||
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
|
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
|
||||||
|
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
@ -42,8 +43,10 @@ import java.io.IOException;
|
||||||
*/
|
*/
|
||||||
public class TestTransferStatsCollection extends BaseDL4JTest {
|
public class TestTransferStatsCollection extends BaseDL4JTest {
|
||||||
|
|
||||||
@Rule
|
@Override
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 90_000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test() throws IOException {
|
public void test() throws IOException {
|
||||||
|
@ -62,9 +65,7 @@ public class TestTransferStatsCollection extends BaseDL4JTest {
|
||||||
new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
|
new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
|
||||||
.setFeatureExtractor(0).build();
|
.setFeatureExtractor(0).build();
|
||||||
|
|
||||||
File dir = testDir.newFolder();
|
net2.setListeners(new StatsListener(new InMemoryStatsStorage()));
|
||||||
File f = new File(dir, "dl4jTestTransferStatsCollection.bin");
|
|
||||||
net2.setListeners(new StatsListener(new FileStatsStorage(f)));
|
|
||||||
|
|
||||||
//Previosuly: failed on frozen layers
|
//Previosuly: failed on frozen layers
|
||||||
net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));
|
net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));
|
||||||
|
|
Loading…
Reference in New Issue