Fixed shape for muli
parent
d5002b14c7
commit
c4307384f3
|
@ -50,6 +50,7 @@ import org.slf4j.Logger;
|
|||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
@ -816,6 +817,37 @@ public class Word2VecTests extends BaseDL4JTest {
|
|||
assertEquals(vec1.getWordVectorMatrix("money"), vec2.getWordVectorMatrix("money"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWordsNearestSum() throws IOException {
|
||||
log.info("Load & Vectorize Sentences....");
|
||||
SentenceIterator iter = new BasicLineIterator(inputFile);
|
||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||
|
||||
log.info("Building model....");
|
||||
Word2Vec vec = new Word2Vec.Builder()
|
||||
.minWordFrequency(5)
|
||||
.iterations(1)
|
||||
.layerSize(100)
|
||||
.seed(42)
|
||||
.windowSize(5)
|
||||
.iterate(iter)
|
||||
.tokenizerFactory(t)
|
||||
.build();
|
||||
|
||||
log.info("Fitting Word2Vec model....");
|
||||
vec.fit();
|
||||
log.info("Writing word vectors to text file....");
|
||||
log.info("Closest Words:");
|
||||
Collection<String> lst = vec.wordsNearestSum("day", 10);
|
||||
log.info("10 Words closest to 'day': {}", lst);
|
||||
assertTrue(lst.contains("week"));
|
||||
assertTrue(lst.contains("night"));
|
||||
assertTrue(lst.contains("year"));
|
||||
assertTrue(lst.contains("years"));
|
||||
assertTrue(lst.contains("time"));
|
||||
}
|
||||
|
||||
private static void printWords(String target, Collection<String> list, Word2Vec vec) {
|
||||
System.out.println("Words close to [" + target + "]:");
|
||||
for (String word : list) {
|
||||
|
|
|
@ -351,7 +351,8 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
|
|||
if (lookupTable instanceof InMemoryLookupTable) {
|
||||
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
|
||||
INDArray syn0 = l.getSyn0();
|
||||
INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
|
||||
INDArray temp = syn0.norm2(0).rdivi(1).reshape(words.shape());
|
||||
INDArray weights = temp.muli(words);
|
||||
INDArray distances = syn0.mulRowVector(weights).sum(1);
|
||||
INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
|
||||
INDArray sort = sorted[0];
|
||||
|
|
Loading…
Reference in New Issue