Assorted fixes (#445)
* #8890 TestTransferStatsCollection timeout fix Signed-off-by: Alex Black <blacka101@gmail.com> * #8890 RegressionTest timeout fix Signed-off-by: Alex Black <blacka101@gmail.com> * Increase timeout for flaky test Signed-off-by: Alex Black <blacka101@gmail.com> * Remove test dependency Signed-off-by: Alex Black <blacka101@gmail.com> * BLAS iamax + loop -> argmax Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
57d16653c8
commit
91a2004d8f
|
@ -56,8 +56,10 @@ import static org.junit.Assert.*;
|
|||
*/
|
||||
public class RegressionTest050 extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType(){
|
||||
|
|
|
@ -64,6 +64,11 @@ public class RegressionTest060 extends BaseDL4JTest {
|
|||
return DataType.FLOAT;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Test
|
||||
public void regressionTestMLP1() throws Exception {
|
||||
|
||||
|
|
|
@ -64,6 +64,12 @@ public class RegressionTest071 extends BaseDL4JTest {
|
|||
public DataType getDataType(){
|
||||
return DataType.FLOAT;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Test
|
||||
public void regressionTestMLP1() throws Exception {
|
||||
|
||||
|
|
|
@ -64,6 +64,11 @@ public class RegressionTest080 extends BaseDL4JTest {
|
|||
return DataType.FLOAT;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Test
|
||||
public void regressionTestMLP1() throws Exception {
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ public class RegressionTest100a extends BaseDL4JTest {
|
|||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -52,7 +52,7 @@ public class RegressionTest100b3 extends BaseDL4JTest {
|
|||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -71,7 +71,7 @@ public class RegressionTest100b4 extends BaseDL4JTest {
|
|||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -58,7 +58,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
|||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -30,6 +30,11 @@ import static org.junit.Assert.assertTrue;
|
|||
*/
|
||||
public class TestDistributionDeserializer extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDistributionDeserializer() throws Exception {
|
||||
//Test current format:
|
||||
|
|
|
@ -46,6 +46,11 @@ public class TestDeepWalk extends BaseDL4JTest {
|
|||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 120_000L; //Increase timeout due to intermittently slow CI machines
|
||||
}
|
||||
|
||||
@Test(timeout = 60000L)
|
||||
public void testBasic() throws IOException {
|
||||
//Very basic test. Load graph, build tree, call fit, make sure it doesn't throw any exceptions
|
||||
|
|
|
@ -84,12 +84,6 @@
|
|||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.awaitility</groupId>
|
||||
<artifactId>awaitility</artifactId>
|
||||
<version>4.0.2</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -968,7 +968,7 @@ public class WordVectorSerializer {
|
|||
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes,
|
||||
@NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
|
||||
// first we load syn0
|
||||
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors));
|
||||
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors)); //Note stream is closed in loadTxt
|
||||
InMemoryLookupTable lookupTable = pair.getFirst();
|
||||
lookupTable.setNegative(configuration.getNegative());
|
||||
if (configuration.getNegative() > 0)
|
||||
|
@ -1607,7 +1607,7 @@ public class WordVectorSerializer {
|
|||
*/
|
||||
@Deprecated
|
||||
public static WordVectors loadTxtVectors(File vectorsFile) throws IOException {
|
||||
FileInputStream fileInputStream = new FileInputStream(vectorsFile);
|
||||
FileInputStream fileInputStream = new FileInputStream(vectorsFile); //Note stream is closed in loadTxt
|
||||
Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(fileInputStream);
|
||||
return fromPair(pair);
|
||||
}
|
||||
|
|
|
@ -49,8 +49,8 @@ import java.io.File;
|
|||
import java.util.Collection;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
import static org.awaitility.Awaitility.await;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
|
||||
@Slf4j
|
||||
|
@ -206,12 +206,6 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
|
|||
final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
|
||||
|
||||
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
|
||||
await()
|
||||
.until(new Callable<Boolean>() {
|
||||
@Override
|
||||
public Boolean call() {
|
||||
return net.params().equalsWithEps(restored.params(), 2e-3);
|
||||
}
|
||||
});
|
||||
assertTrue(net.params().equalsWithEps(restored.params(), 2e-3));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer;
|
|||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.deeplearning4j.optimize.Solver;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -247,10 +248,8 @@ public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.c
|
|||
@Override
|
||||
public int[] predict(INDArray input) {
|
||||
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
|
||||
int[] ret = new int[input.rows()];
|
||||
for (int i = 0; i < ret.length; i++)
|
||||
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
|
||||
return ret;
|
||||
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
|
||||
return output.argMax(1).toIntVector();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|||
import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.optimize.Solver;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
@ -251,10 +252,8 @@ public class LossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.LossL
|
|||
@Override
|
||||
public int[] predict(INDArray input) {
|
||||
INDArray output = activate(input, false, LayerWorkspaceMgr.noWorkspacesImmutable());
|
||||
int[] ret = new int[input.rows()];
|
||||
for (int i = 0; i < ret.length; i++)
|
||||
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
|
||||
return ret;
|
||||
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
|
||||
return output.argMax(1).toIntVector();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -2220,14 +2220,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
if (d.size(0) > Integer.MAX_VALUE)
|
||||
throw new ND4JArraySizeException();
|
||||
|
||||
int[] ret = new int[(int) d.size(0)];
|
||||
if (d.isRowVectorOrScalar())
|
||||
ret[0] = Nd4j.getBlasWrapper().iamax(output);
|
||||
else {
|
||||
for (int i = 0; i < ret.length; i++)
|
||||
ret[i] = Nd4j.getBlasWrapper().iamax(output.getRow(i));
|
||||
}
|
||||
return ret;
|
||||
Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank());
|
||||
return output.argMax(1).toIntVector();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.deeplearning4j.nn.transferlearning.TransferLearning;
|
|||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.ui.model.stats.StatsListener;
|
||||
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
|
||||
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
|
@ -42,8 +43,10 @@ import java.io.IOException;
|
|||
*/
|
||||
public class TestTransferStatsCollection extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90_000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test() throws IOException {
|
||||
|
@ -62,9 +65,7 @@ public class TestTransferStatsCollection extends BaseDL4JTest {
|
|||
new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build())
|
||||
.setFeatureExtractor(0).build();
|
||||
|
||||
File dir = testDir.newFolder();
|
||||
File f = new File(dir, "dl4jTestTransferStatsCollection.bin");
|
||||
net2.setListeners(new StatsListener(new FileStatsStorage(f)));
|
||||
net2.setListeners(new StatsListener(new InMemoryStatsStorage()));
|
||||
|
||||
//Previosuly: failed on frozen layers
|
||||
net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10)));
|
||||
|
|
Loading…
Reference in New Issue