Test fixes (#218)
* Test speedups / integration test run only for CUDA - NLP Signed-off-by: AlexDBlack <blacka101@gmail.com> * nlp-uima CUDA slow tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * Spark CUDA timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
948646b32d
commit
ce6848c9fe
|
@ -16,25 +16,19 @@
|
||||||
|
|
||||||
package org.deeplearning4j.models;
|
package org.deeplearning4j.models;
|
||||||
|
|
||||||
import org.junit.rules.Timeout;
|
|
||||||
import org.nd4j.shade.guava.io.Files;
|
|
||||||
import org.nd4j.shade.guava.primitives.Doubles;
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.lang.ArrayUtils;
|
import org.apache.commons.lang.ArrayUtils;
|
||||||
import org.apache.commons.lang3.RandomUtils;
|
import org.apache.commons.lang3.RandomUtils;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
|
||||||
import org.deeplearning4j.models.sequencevectors.serialization.VocabWordFactory;
|
|
||||||
import org.junit.Rule;
|
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
|
||||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||||
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||||
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
||||||
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
|
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
|
||||||
|
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
||||||
|
import org.deeplearning4j.models.sequencevectors.serialization.VocabWordFactory;
|
||||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||||
import org.deeplearning4j.models.word2vec.Word2Vec;
|
import org.deeplearning4j.models.word2vec.Word2Vec;
|
||||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||||
|
@ -48,11 +42,16 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFac
|
||||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
import org.junit.rules.Timeout;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import org.nd4j.resources.Resources;
|
import org.nd4j.resources.Resources;
|
||||||
|
import org.nd4j.shade.guava.primitives.Doubles;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
@ -272,7 +271,14 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFullModelSerialization() throws Exception {
|
public void testFullModelSerialization() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
|
||||||
|
}
|
||||||
|
|
||||||
File inputFile = Resources.asFile("big/raw_sentences.txt");
|
File inputFile = Resources.asFile("big/raw_sentences.txt");
|
||||||
|
|
||||||
|
|
||||||
SentenceIterator iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
|
SentenceIterator iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
|
||||||
// Split on white spaces in the line to get words
|
// Split on white spaces in the line to get words
|
||||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||||
|
@ -892,5 +898,4 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
fail(e.getMessage());
|
fail(e.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -159,6 +159,11 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWord2VecCBOW() throws Exception {
|
public void testWord2VecCBOW() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
|
||||||
|
}
|
||||||
|
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
|
|
||||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||||
|
@ -188,6 +193,11 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWord2VecMultiEpoch() throws Exception {
|
public void testWord2VecMultiEpoch() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
|
||||||
|
}
|
||||||
|
|
||||||
SentenceIterator iter;
|
SentenceIterator iter;
|
||||||
if(isIntegrationTests()){
|
if(isIntegrationTests()){
|
||||||
iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
|
@ -220,6 +230,11 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void reproducibleResults_ForMultipleRuns() throws Exception {
|
public void reproducibleResults_ForMultipleRuns() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
|
||||||
|
}
|
||||||
|
|
||||||
log.info("reproducibleResults_ForMultipleRuns");
|
log.info("reproducibleResults_ForMultipleRuns");
|
||||||
val shakespear = new ClassPathResource("big/rnj.txt");
|
val shakespear = new ClassPathResource("big/rnj.txt");
|
||||||
val basic = new ClassPathResource("big/rnj.txt");
|
val basic = new ClassPathResource("big/rnj.txt");
|
||||||
|
@ -274,6 +289,11 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRunWord2Vec() throws Exception {
|
public void testRunWord2Vec() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
|
||||||
|
}
|
||||||
|
|
||||||
// Strip white space before and after for each line
|
// Strip white space before and after for each line
|
||||||
/*val shakespear = new ClassPathResource("big/rnj.txt");
|
/*val shakespear = new ClassPathResource("big/rnj.txt");
|
||||||
SentenceIterator iter = new BasicLineIterator(shakespear.getFile());*/
|
SentenceIterator iter = new BasicLineIterator(shakespear.getFile());*/
|
||||||
|
@ -363,6 +383,11 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLoadingWordVectors() throws Exception {
|
public void testLoadingWordVectors() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
|
||||||
|
}
|
||||||
|
|
||||||
File modelFile = new File(pathToWriteto);
|
File modelFile = new File(pathToWriteto);
|
||||||
if (!modelFile.exists()) {
|
if (!modelFile.exists()) {
|
||||||
testRunWord2Vec();
|
testRunWord2Vec();
|
||||||
|
@ -396,6 +421,11 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testW2VnegativeOnRestore() throws Exception {
|
public void testW2VnegativeOnRestore() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
|
||||||
|
}
|
||||||
|
|
||||||
// Strip white space before and after for each line
|
// Strip white space before and after for each line
|
||||||
SentenceIterator iter;
|
SentenceIterator iter;
|
||||||
if(isIntegrationTests()){
|
if(isIntegrationTests()){
|
||||||
|
@ -453,6 +483,11 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testUnknown1() throws Exception {
|
public void testUnknown1() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
|
||||||
|
}
|
||||||
|
|
||||||
// Strip white space before and after for each line
|
// Strip white space before and after for each line
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
// Split on white spaces in the line to get words
|
// Split on white spaces in the line to get words
|
||||||
|
@ -688,6 +723,10 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWordVectorsPartiallyAbsentLabels() throws Exception {
|
public void testWordVectorsPartiallyAbsentLabels() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
|
||||||
|
}
|
||||||
|
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
// Split on white spaces in the line to get words
|
// Split on white spaces in the line to get words
|
||||||
|
@ -720,6 +759,10 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWordVectorsAbsentLabels() throws Exception {
|
public void testWordVectorsAbsentLabels() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
|
||||||
|
}
|
||||||
|
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
// Split on white spaces in the line to get words
|
// Split on white spaces in the line to get words
|
||||||
|
@ -745,6 +788,10 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWordVectorsAbsentLabels_WithUnknown() throws Exception {
|
public void testWordVectorsAbsentLabels_WithUnknown() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
|
||||||
|
}
|
||||||
|
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
// Split on white spaces in the line to get words
|
// Split on white spaces in the line to get words
|
||||||
|
@ -814,6 +861,10 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void weightsNotUpdated_WhenLocked_CBOW() throws Exception {
|
public void weightsNotUpdated_WhenLocked_CBOW() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
|
||||||
|
}
|
||||||
|
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
|
|
||||||
|
@ -851,6 +902,11 @@ public class Word2VecTests extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testWordsNearestSum() throws IOException {
|
public void testWordsNearestSum() throws IOException {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
|
||||||
|
}
|
||||||
|
|
||||||
log.info("Load & Vectorize Sentences....");
|
log.info("Load & Vectorize Sentences....");
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile);
|
SentenceIterator iter = new BasicLineIterator(inputFile);
|
||||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||||
|
|
|
@ -48,12 +48,22 @@ public class TsneTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 60000L;
|
return 180000L;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDefaultFPDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSimple() throws Exception {
|
public void testSimple() throws Exception {
|
||||||
//Simple sanity check
|
//Simple sanity check
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.Par
|
||||||
import org.deeplearning4j.text.sentenceiterator.*;
|
import org.deeplearning4j.text.sentenceiterator.*;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||||
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
|
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
|
||||||
|
@ -80,12 +81,21 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 240000;
|
return isIntegrationTests() ? 600_000 : 240_000;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDefaultFPDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@Test
|
@Test
|
||||||
|
@ -359,8 +369,13 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test(timeout = 300000)
|
@Test
|
||||||
public void testParagraphVectorsDM() throws Exception {
|
public void testParagraphVectorsDM() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //Skip CUDA except for integration tests due to very slow test speed
|
||||||
|
}
|
||||||
|
|
||||||
File file = Resources.asFile("/big/raw_sentences.txt");
|
File file = Resources.asFile("/big/raw_sentences.txt");
|
||||||
SentenceIterator iter = new BasicLineIterator(file);
|
SentenceIterator iter = new BasicLineIterator(file);
|
||||||
|
|
||||||
|
@ -372,10 +387,10 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
LabelsSource source = new LabelsSource("DOC_");
|
LabelsSource source = new LabelsSource("DOC_");
|
||||||
|
|
||||||
ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(1)
|
ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(1)
|
||||||
.layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter)
|
.layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter)
|
||||||
.trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0)
|
.trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0)
|
||||||
.useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true)
|
.useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true)
|
||||||
.sequenceLearningAlgorithm(new DM<VocabWord>()).build();
|
.sequenceLearningAlgorithm(new DM<VocabWord>()).build();
|
||||||
|
|
||||||
vec.fit();
|
vec.fit();
|
||||||
|
|
||||||
|
@ -404,7 +419,9 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
|
|
||||||
double similarityX = vec.similarity("DOC_3720", "DOC_9852");
|
double similarityX = vec.similarity("DOC_3720", "DOC_9852");
|
||||||
log.info("3720/9852 similarity: " + similarityX);
|
log.info("3720/9852 similarity: " + similarityX);
|
||||||
assertTrue(similarityX < 0.5d);
|
if(isIntegrationTests()) {
|
||||||
|
assertTrue(similarityX < 0.5d);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// testing DM inference now
|
// testing DM inference now
|
||||||
|
@ -418,7 +435,6 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
|
|
||||||
log.info("Cos O/A: {}", cosAO1);
|
log.info("Cos O/A: {}", cosAO1);
|
||||||
log.info("Cos A/B: {}", cosAB1);
|
log.info("Cos A/B: {}", cosAB1);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -501,6 +517,11 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test(timeout = 300000)
|
@Test(timeout = 300000)
|
||||||
public void testParagraphVectorsWithWordVectorsModelling1() throws Exception {
|
public void testParagraphVectorsWithWordVectorsModelling1() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //Skip CUDA except for integration tests due to very slow test speed
|
||||||
|
}
|
||||||
|
|
||||||
File file = Resources.asFile("/big/raw_sentences.txt");
|
File file = Resources.asFile("/big/raw_sentences.txt");
|
||||||
SentenceIterator iter = new BasicLineIterator(file);
|
SentenceIterator iter = new BasicLineIterator(file);
|
||||||
|
|
||||||
|
@ -705,8 +726,12 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
In this test we'll build w2v model, and will use it's vocab and weights for ParagraphVectors.
|
In this test we'll build w2v model, and will use it's vocab and weights for ParagraphVectors.
|
||||||
there's no need in this test within travis, use it manually only for problems detection
|
there's no need in this test within travis, use it manually only for problems detection
|
||||||
*/
|
*/
|
||||||
@Test(timeout = 300000)
|
@Test
|
||||||
public void testParagraphVectorsOverExistingWordVectorsModel() throws Exception {
|
public void testParagraphVectorsOverExistingWordVectorsModel() throws Exception {
|
||||||
|
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
|
||||||
|
if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
|
||||||
|
skipUnlessIntegrationTests(); //Skip CUDA except for integration tests due to very slow test speed
|
||||||
|
}
|
||||||
|
|
||||||
// we build w2v from multiple sources, to cover everything
|
// we build w2v from multiple sources, to cover everything
|
||||||
File resource_sentences = Resources.asFile("/big/raw_sentences.txt");
|
File resource_sentences = Resources.asFile("/big/raw_sentences.txt");
|
||||||
|
@ -997,14 +1022,18 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
log.info("SimilarityB: {}", simB);
|
log.info("SimilarityB: {}", simB);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 300000)
|
@Test
|
||||||
|
@Ignore //AB 2020/02/06 - https://github.com/eclipse/deeplearning4j/issues/8677
|
||||||
public void testDirectInference() throws Exception {
|
public void testDirectInference() throws Exception {
|
||||||
File resource_sentences = Resources.asFile("/big/raw_sentences.txt");
|
boolean isIntegration = isIntegrationTests();
|
||||||
|
File resource = Resources.asFile("/big/raw_sentences.txt");
|
||||||
|
SentenceIterator sentencesIter = getIterator(isIntegration, resource);
|
||||||
|
|
||||||
ClassPathResource resource_mixed = new ClassPathResource("paravec/");
|
ClassPathResource resource_mixed = new ClassPathResource("paravec/");
|
||||||
File local_resource_mixed = testDir.newFolder();
|
File local_resource_mixed = testDir.newFolder();
|
||||||
resource_mixed.copyDirectory(local_resource_mixed);
|
resource_mixed.copyDirectory(local_resource_mixed);
|
||||||
SentenceIterator iter = new AggregatingSentenceIterator.Builder()
|
SentenceIterator iter = new AggregatingSentenceIterator.Builder()
|
||||||
.addSentenceIterator(new BasicLineIterator(resource_sentences))
|
.addSentenceIterator(sentencesIter)
|
||||||
.addSentenceIterator(new FileSentenceIterator(local_resource_mixed)).build();
|
.addSentenceIterator(new FileSentenceIterator(local_resource_mixed)).build();
|
||||||
|
|
||||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||||
|
@ -1154,24 +1183,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
public void testDoubleFit() throws Exception {
|
public void testDoubleFit() throws Exception {
|
||||||
boolean isIntegration = isIntegrationTests();
|
boolean isIntegration = isIntegrationTests();
|
||||||
File resource = Resources.asFile("/big/raw_sentences.txt");
|
File resource = Resources.asFile("/big/raw_sentences.txt");
|
||||||
SentenceIterator iter;
|
SentenceIterator iter = getIterator(isIntegration, resource);
|
||||||
if(isIntegration){
|
|
||||||
iter = new BasicLineIterator(resource);
|
|
||||||
} else {
|
|
||||||
List<String> lines = new ArrayList<>();
|
|
||||||
try(InputStream is = new BufferedInputStream(new FileInputStream(resource))){
|
|
||||||
LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
|
|
||||||
try{
|
|
||||||
for( int i=0; i<500 && lineIter.hasNext(); i++ ){
|
|
||||||
lines.add(lineIter.next());
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
lineIter.close();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
iter = new CollectionSentenceIterator(lines);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||||
|
@ -1197,6 +1209,30 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
|
||||||
|
|
||||||
assertEquals(num1, num2);
|
assertEquals(num1, num2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static SentenceIterator getIterator(boolean isIntegration, File file) throws IOException {
|
||||||
|
return getIterator(isIntegration, file, 500);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static SentenceIterator getIterator(boolean isIntegration, File file, int linesForUnitTest) throws IOException {
|
||||||
|
if(isIntegration){
|
||||||
|
return new BasicLineIterator(file);
|
||||||
|
} else {
|
||||||
|
List<String> lines = new ArrayList<>();
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(file))){
|
||||||
|
LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
|
||||||
|
try{
|
||||||
|
for( int i=0; i<linesForUnitTest && lineIter.hasNext(); i++ ){
|
||||||
|
lines.add(lineIter.next());
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
lineIter.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return new CollectionSentenceIterator(lines);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.models.word2vec;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.models.paragraphvectors.ParagraphVectorsTest;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
|
@ -56,6 +57,11 @@ import static org.junit.Assert.assertEquals;
|
||||||
public class Word2VecTestsSmall extends BaseDL4JTest {
|
public class Word2VecTestsSmall extends BaseDL4JTest {
|
||||||
WordVectors word2vec;
|
WordVectors word2vec;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return isIntegrationTests() ? 240000 : 60000;
|
||||||
|
}
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() throws Exception {
|
public void setUp() throws Exception {
|
||||||
word2vec = WordVectorSerializer.readWord2VecModel(new ClassPathResource("vec.bin").getFile());
|
word2vec = WordVectorSerializer.readWord2VecModel(new ClassPathResource("vec.bin").getFile());
|
||||||
|
@ -85,8 +91,8 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
|
||||||
@Test(timeout = 300000)
|
@Test(timeout = 300000)
|
||||||
public void testUnkSerialization_1() throws Exception {
|
public void testUnkSerialization_1() throws Exception {
|
||||||
val inputFile = Resources.asFile("big/raw_sentences.txt");
|
val inputFile = Resources.asFile("big/raw_sentences.txt");
|
||||||
|
// val iter = new BasicLineIterator(inputFile);
|
||||||
val iter = new BasicLineIterator(inputFile);
|
val iter = ParagraphVectorsTest.getIterator(isIntegrationTests(), inputFile);
|
||||||
val t = new DefaultTokenizerFactory();
|
val t = new DefaultTokenizerFactory();
|
||||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||||
|
|
||||||
|
@ -147,8 +153,8 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
|
||||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||||
|
|
||||||
val inputFile = Resources.asFile("big/raw_sentences.txt");
|
val inputFile = Resources.asFile("big/raw_sentences.txt");
|
||||||
|
val iter = ParagraphVectorsTest.getIterator(isIntegrationTests(), inputFile);
|
||||||
val iter = new BasicLineIterator(inputFile);
|
// val iter = new BasicLineIterator(inputFile);
|
||||||
val t = new DefaultTokenizerFactory();
|
val t = new DefaultTokenizerFactory();
|
||||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.models.word2vec.iterator;
|
package org.deeplearning4j.models.word2vec.iterator;
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.models.paragraphvectors.ParagraphVectorsTest;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
|
import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
|
||||||
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
||||||
|
@ -59,7 +60,8 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
|
||||||
public void testIterator1() throws Exception {
|
public void testIterator1() throws Exception {
|
||||||
|
|
||||||
File inputFile = Resources.asFile("big/raw_sentences.txt");
|
File inputFile = Resources.asFile("big/raw_sentences.txt");
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
SentenceIterator iter = ParagraphVectorsTest.getIterator(isIntegrationTests(), inputFile);
|
||||||
|
// SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
|
|
||||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||||
|
|
|
@ -58,6 +58,7 @@ import org.nd4j.evaluation.classification.Evaluation;
|
||||||
import org.nd4j.evaluation.classification.ROC;
|
import org.nd4j.evaluation.classification.ROC;
|
||||||
import org.nd4j.evaluation.classification.ROCMultiClass;
|
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||||
|
@ -93,7 +94,23 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 120000L;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDefaultFPDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataType getDataType() {
|
||||||
|
return DataType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
public void testFromSvmLightBackprop() throws Exception {
|
public void testFromSvmLightBackprop() throws Exception {
|
||||||
JavaRDD<LabeledPoint> data = MLUtils
|
JavaRDD<LabeledPoint> data = MLUtils
|
||||||
.loadLibSVMFile(sc.sc(),
|
.loadLibSVMFile(sc.sc(),
|
||||||
|
@ -125,7 +142,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testFromSvmLight() throws Exception {
|
public void testFromSvmLight() throws Exception {
|
||||||
JavaRDD<LabeledPoint> data = MLUtils
|
JavaRDD<LabeledPoint> data = MLUtils
|
||||||
.loadLibSVMFile(sc.sc(),
|
.loadLibSVMFile(sc.sc(),
|
||||||
|
@ -155,7 +172,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
master.fitLabeledPoint(data);
|
master.fitLabeledPoint(data);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testRunIteration() {
|
public void testRunIteration() {
|
||||||
|
|
||||||
DataSet dataSet = new IrisDataSetIterator(5, 5).next();
|
DataSet dataSet = new IrisDataSetIterator(5, 5).next();
|
||||||
|
@ -175,7 +192,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
assertEquals(expectedParams.size(1), actualParams.size(1));
|
assertEquals(expectedParams.size(1), actualParams.size(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testUpdaters() {
|
public void testUpdaters() {
|
||||||
SparkDl4jMultiLayer sparkNet = getBasicNetwork();
|
SparkDl4jMultiLayer sparkNet = getBasicNetwork();
|
||||||
MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();
|
MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();
|
||||||
|
@ -197,7 +214,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testEvaluation() {
|
public void testEvaluation() {
|
||||||
|
|
||||||
SparkDl4jMultiLayer sparkNet = getBasicNetwork();
|
SparkDl4jMultiLayer sparkNet = getBasicNetwork();
|
||||||
|
@ -228,7 +245,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testSmallAmountOfData() {
|
public void testSmallAmountOfData() {
|
||||||
//Idea: Test spark training where some executors don't get any data
|
//Idea: Test spark training where some executors don't get any data
|
||||||
//in this case: by having fewer examples (2 DataSets) than executors (local[*])
|
//in this case: by having fewer examples (2 DataSets) than executors (local[*])
|
||||||
|
@ -255,7 +272,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testDistributedScoring() {
|
public void testDistributedScoring() {
|
||||||
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.1)
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.1)
|
||||||
|
@ -333,7 +350,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testParameterAveragingMultipleExamplesPerDataSet() throws Exception {
|
public void testParameterAveragingMultipleExamplesPerDataSet() throws Exception {
|
||||||
int dataSetObjSize = 5;
|
int dataSetObjSize = 5;
|
||||||
int batchSizePerExecutor = 25;
|
int batchSizePerExecutor = 25;
|
||||||
|
@ -382,7 +399,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testFitViaStringPaths() throws Exception {
|
public void testFitViaStringPaths() throws Exception {
|
||||||
|
|
||||||
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPaths").toPath();
|
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPaths").toPath();
|
||||||
|
@ -445,7 +462,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
sparkNet.getTrainingMaster().deleteTempFiles(sc);
|
sparkNet.getTrainingMaster().deleteTempFiles(sc);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testFitViaStringPathsSize1() throws Exception {
|
public void testFitViaStringPathsSize1() throws Exception {
|
||||||
|
|
||||||
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsSize1").toPath();
|
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsSize1").toPath();
|
||||||
|
@ -525,7 +542,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testFitViaStringPathsCompGraph() throws Exception {
|
public void testFitViaStringPathsCompGraph() throws Exception {
|
||||||
|
|
||||||
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsCG").toPath();
|
Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsCG").toPath();
|
||||||
|
@ -618,7 +635,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
@Ignore("AB 2019/05/23 - Failing on CI only - passing locally. Possible precision or threading issue")
|
@Ignore("AB 2019/05/23 - Failing on CI only - passing locally. Possible precision or threading issue")
|
||||||
public void testSeedRepeatability() throws Exception {
|
public void testSeedRepeatability() throws Exception {
|
||||||
|
|
||||||
|
@ -691,7 +708,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testIterationCounts() throws Exception {
|
public void testIterationCounts() throws Exception {
|
||||||
int dataSetObjSize = 5;
|
int dataSetObjSize = 5;
|
||||||
int batchSizePerExecutor = 25;
|
int batchSizePerExecutor = 25;
|
||||||
|
@ -737,7 +754,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testIterationCountsGraph() throws Exception {
|
public void testIterationCountsGraph() throws Exception {
|
||||||
int dataSetObjSize = 5;
|
int dataSetObjSize = 5;
|
||||||
int batchSizePerExecutor = 25;
|
int batchSizePerExecutor = 25;
|
||||||
|
@ -783,7 +800,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test(timeout = 120000L) @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
|
@Test
|
||||||
|
@Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
|
||||||
public void testVaePretrainSimple() {
|
public void testVaePretrainSimple() {
|
||||||
//Simple sanity check on pretraining
|
//Simple sanity check on pretraining
|
||||||
int nIn = 8;
|
int nIn = 8;
|
||||||
|
@ -818,7 +836,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
sparkNet.fit(data);
|
sparkNet.fit(data);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 120000L) @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
|
@Test
|
||||||
|
@Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
|
||||||
public void testVaePretrainSimpleCG() {
|
public void testVaePretrainSimpleCG() {
|
||||||
//Simple sanity check on pretraining
|
//Simple sanity check on pretraining
|
||||||
int nIn = 8;
|
int nIn = 8;
|
||||||
|
@ -854,7 +873,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testROC() {
|
public void testROC() {
|
||||||
|
|
||||||
int nArrays = 100;
|
int nArrays = 100;
|
||||||
|
@ -909,7 +928,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test(timeout = 120000L)
|
@Test
|
||||||
public void testROCMultiClass() {
|
public void testROCMultiClass() {
|
||||||
|
|
||||||
int nArrays = 100;
|
int nArrays = 100;
|
||||||
|
|
Loading…
Reference in New Issue