Fixed shape for muli

master
Alexander Stoyakin 2019-10-16 12:59:25 +03:00
parent d5002b14c7
commit c4307384f3
2 changed files with 34 additions and 1 deletions

View File

@ -50,6 +50,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.*;
import static org.junit.Assert.*;
@ -816,6 +817,37 @@ public class Word2VecTests extends BaseDL4JTest {
assertEquals(vec1.getWordVectorMatrix("money"), vec2.getWordVectorMatrix("money"));
}
@Test
public void testWordsNearestSum() throws IOException {
log.info("Load & Vectorize Sentences....");
SentenceIterator iter = new BasicLineIterator(inputFile);
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
log.info("Building model....");
Word2Vec vec = new Word2Vec.Builder()
.minWordFrequency(5)
.iterations(1)
.layerSize(100)
.seed(42)
.windowSize(5)
.iterate(iter)
.tokenizerFactory(t)
.build();
log.info("Fitting Word2Vec model....");
vec.fit();
log.info("Writing word vectors to text file....");
log.info("Closest Words:");
Collection<String> lst = vec.wordsNearestSum("day", 10);
log.info("10 Words closest to 'day': {}", lst);
assertTrue(lst.contains("week"));
assertTrue(lst.contains("night"));
assertTrue(lst.contains("year"));
assertTrue(lst.contains("years"));
assertTrue(lst.contains("time"));
}
private static void printWords(String target, Collection<String> list, Word2Vec vec) {
System.out.println("Words close to [" + target + "]:");
for (String word : list) {

View File

@ -351,7 +351,8 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
if (lookupTable instanceof InMemoryLookupTable) {
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
INDArray syn0 = l.getSyn0();
INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
INDArray temp = syn0.norm2(0).rdivi(1).reshape(words.shape());
INDArray weights = temp.muli(words);
INDArray distances = syn0.mulRowVector(weights).sum(1);
INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
INDArray sort = sorted[0];