fix: TfidfVectorizer.vectorize() NPE when fitted from LabelAwareIterator

issue #8886

Signed-off-by: Benjamin Possolo <bpossolo@gmail.com>
master
Benjamin Possolo 2020-04-26 09:38:32 -07:00
parent 722d5a052a
commit bc02d9f0b0
3 changed files with 104 additions and 1 deletions

View File

@ -24,6 +24,7 @@ import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache; import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.documentiterator.DocumentIterator; import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator; import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIteratorWrapper;
import org.deeplearning4j.text.documentiterator.LabelsSource; import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter; import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator; import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
@ -186,7 +187,7 @@ public class TfidfVectorizer extends BaseTextVectorizer {
} }
public Builder setIterator(@NonNull LabelAwareIterator iterator) { public Builder setIterator(@NonNull LabelAwareIterator iterator) {
this.iterator = iterator; this.iterator = new LabelAwareIteratorWrapper(iterator, labelsSource);
return this; return this;
} }

View File

@ -0,0 +1,60 @@
package org.deeplearning4j.text.documentiterator;
import java.util.List;
/**
* LabelAwareIterator wrapper which populates a LabelsSource while iterating.
*
* @author Benjamin Possolo
*/
public class LabelAwareIteratorWrapper implements LabelAwareIterator {
private final LabelAwareIterator delegate;
private final LabelsSource sink;
public LabelAwareIteratorWrapper(LabelAwareIterator delegate, LabelsSource sink) {
this.delegate = delegate;
this.sink = sink;
}
@Override
public boolean hasNext() {
return delegate.hasNext();
}
@Override
public boolean hasNextDocument() {
return delegate.hasNextDocument();
}
@Override
public LabelsSource getLabelsSource() {
return sink;
}
@Override
public LabelledDocument next() {
return nextDocument();
}
@Override
public LabelledDocument nextDocument() {
LabelledDocument doc = delegate.nextDocument();
List<String> labels = doc.getLabels();
if (labels != null) {
for (String label : labels) {
sink.storeLabel(label);
}
}
return doc;
}
@Override
public void reset() {
delegate.reset();
sink.reset();
}
@Override
public void shutdown() {}
}

View File

@ -24,6 +24,10 @@ import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.documentiterator.SimpleLabelAwareIterator;
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator; import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator; import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator; import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
@ -40,6 +44,7 @@ import org.nd4j.linalg.util.SerializationUtils;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -131,6 +136,43 @@ public class TfidfVectorizerTest extends BaseDL4JTest {
assertEquals(vector, dataSet.getFeatures()); assertEquals(vector, dataSet.getFeatures());
} }
public void testTfIdfVectorizerFromLabelAwareIterator() throws Exception {
LabelledDocument doc1 = new LabelledDocument();
doc1.addLabel("dog");
doc1.setContent("it barks like a dog");
LabelledDocument doc2 = new LabelledDocument();
doc2.addLabel("cat");
doc2.setContent("it meows like a cat");
List<LabelledDocument> docs = new ArrayList<>(2);
docs.add(doc1);
docs.add(doc2);
LabelAwareIterator iterator = new SimpleLabelAwareIterator(docs);
TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
TfidfVectorizer vectorizer = new TfidfVectorizer
.Builder()
.setMinWordFrequency(1)
.setStopWords(new ArrayList<String>())
.setTokenizerFactory(tokenizerFactory)
.setIterator(iterator)
.allowParallelTokenization(false)
.build();
vectorizer.fit();
DataSet dataset = vectorizer.vectorize("it meows like a cat", "cat");
assertNotNull(dataset);
LabelsSource source = vectorizer.getLabelsSource();
assertEquals(2, source.getNumberOfLabelsUsed());
List<String> labels = source.getLabels();
assertEquals("dog", labels.get(0));
assertEquals("cat", labels.get(1));
}
@Test(timeout = 10000L) @Test(timeout = 10000L)
public void testParallelFlag1() throws Exception { public void testParallelFlag1() throws Exception {
val vectorizer = new TfidfVectorizer.Builder() val vectorizer = new TfidfVectorizer.Builder()