Small test fix (#216)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-09-02 16:44:57 +10:00 committed by GitHub
parent b3a134b608
commit acf559425a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 4 deletions

View File

@ -25,7 +25,6 @@ import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunction; import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunction;
import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunctionAdapter;
import org.deeplearning4j.spark.models.embeddings.word2vec.MapToPairFunction; import org.deeplearning4j.spark.models.embeddings.word2vec.MapToPairFunction;
import org.deeplearning4j.spark.models.embeddings.word2vec.Word2Vec; import org.deeplearning4j.spark.models.embeddings.word2vec.Word2Vec;
import org.deeplearning4j.spark.text.functions.CountCumSum; import org.deeplearning4j.spark.text.functions.CountCumSum;
@ -470,11 +469,11 @@ public class TextPipelineTest extends BaseSparkTest {
Iterator<Tuple2<List<VocabWord>, Long>> iterator = vocabWordListSentenceCumSumRDD.collect().iterator(); Iterator<Tuple2<List<VocabWord>, Long>> iterator = vocabWordListSentenceCumSumRDD.collect().iterator();
FirstIterationFunctionAdapter firstIterationFunction = new FirstIterationFunctionAdapter( FirstIterationFunction firstIterationFunction = new FirstIterationFunction(
word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache()); word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache());
Iterable<Map.Entry<VocabWord, INDArray>> ret = firstIterationFunction.call(iterator); Iterator<Map.Entry<VocabWord, INDArray>> ret = firstIterationFunction.call(iterator);
assertTrue(ret.iterator().hasNext()); assertTrue(ret.hasNext());
} }
@Test @Test