fix: TfidfVectorizer.vectorize() NPE when fitted from LabelAwareIterator
issue #8886
Signed-off-by: Benjamin Possolo <bpossolo@gmail.com>
master
parent
722d5a052a
commit
bc02d9f0b0
|
@ -24,6 +24,7 @@ import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
|||
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
|
||||
import org.deeplearning4j.text.documentiterator.DocumentIterator;
|
||||
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
|
||||
import org.deeplearning4j.text.documentiterator.LabelAwareIteratorWrapper;
|
||||
import org.deeplearning4j.text.documentiterator.LabelsSource;
|
||||
import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
|
||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||
|
@ -186,7 +187,7 @@ public class TfidfVectorizer extends BaseTextVectorizer {
|
|||
}
|
||||
|
||||
public Builder setIterator(@NonNull LabelAwareIterator iterator) {
|
||||
this.iterator = iterator;
|
||||
this.iterator = new LabelAwareIteratorWrapper(iterator, labelsSource);
|
||||
return this;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
package org.deeplearning4j.text.documentiterator;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* LabelAwareIterator wrapper which populates a LabelsSource while iterating.
|
||||
*
|
||||
* @author Benjamin Possolo
|
||||
*/
|
||||
public class LabelAwareIteratorWrapper implements LabelAwareIterator {
|
||||
|
||||
private final LabelAwareIterator delegate;
|
||||
private final LabelsSource sink;
|
||||
|
||||
public LabelAwareIteratorWrapper(LabelAwareIterator delegate, LabelsSource sink) {
|
||||
this.delegate = delegate;
|
||||
this.sink = sink;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return delegate.hasNext();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNextDocument() {
|
||||
return delegate.hasNextDocument();
|
||||
}
|
||||
|
||||
@Override
|
||||
public LabelsSource getLabelsSource() {
|
||||
return sink;
|
||||
}
|
||||
|
||||
@Override
|
||||
public LabelledDocument next() {
|
||||
return nextDocument();
|
||||
}
|
||||
|
||||
@Override
|
||||
public LabelledDocument nextDocument() {
|
||||
LabelledDocument doc = delegate.nextDocument();
|
||||
List<String> labels = doc.getLabels();
|
||||
if (labels != null) {
|
||||
for (String label : labels) {
|
||||
sink.storeLabel(label);
|
||||
}
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
delegate.reset();
|
||||
sink.reset();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void shutdown() {}
|
||||
}
|
|
@ -24,6 +24,10 @@ import org.junit.rules.TemporaryFolder;
|
|||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
|
||||
import org.deeplearning4j.text.documentiterator.LabelledDocument;
|
||||
import org.deeplearning4j.text.documentiterator.LabelsSource;
|
||||
import org.deeplearning4j.text.documentiterator.SimpleLabelAwareIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
|
||||
|
@ -40,6 +44,7 @@ import org.nd4j.linalg.util.SerializationUtils;
|
|||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
@ -131,6 +136,43 @@ public class TfidfVectorizerTest extends BaseDL4JTest {
|
|||
assertEquals(vector, dataSet.getFeatures());
|
||||
}
|
||||
|
||||
public void testTfIdfVectorizerFromLabelAwareIterator() throws Exception {
|
||||
LabelledDocument doc1 = new LabelledDocument();
|
||||
doc1.addLabel("dog");
|
||||
doc1.setContent("it barks like a dog");
|
||||
|
||||
LabelledDocument doc2 = new LabelledDocument();
|
||||
doc2.addLabel("cat");
|
||||
doc2.setContent("it meows like a cat");
|
||||
|
||||
List<LabelledDocument> docs = new ArrayList<>(2);
|
||||
docs.add(doc1);
|
||||
docs.add(doc2);
|
||||
|
||||
LabelAwareIterator iterator = new SimpleLabelAwareIterator(docs);
|
||||
TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
|
||||
|
||||
TfidfVectorizer vectorizer = new TfidfVectorizer
|
||||
.Builder()
|
||||
.setMinWordFrequency(1)
|
||||
.setStopWords(new ArrayList<String>())
|
||||
.setTokenizerFactory(tokenizerFactory)
|
||||
.setIterator(iterator)
|
||||
.allowParallelTokenization(false)
|
||||
.build();
|
||||
|
||||
vectorizer.fit();
|
||||
|
||||
DataSet dataset = vectorizer.vectorize("it meows like a cat", "cat");
|
||||
assertNotNull(dataset);
|
||||
|
||||
LabelsSource source = vectorizer.getLabelsSource();
|
||||
assertEquals(2, source.getNumberOfLabelsUsed());
|
||||
List<String> labels = source.getLabels();
|
||||
assertEquals("dog", labels.get(0));
|
||||
assertEquals("cat", labels.get(1));
|
||||
}
|
||||
|
||||
@Test(timeout = 10000L)
|
||||
public void testParallelFlag1() throws Exception {
|
||||
val vectorizer = new TfidfVectorizer.Builder()
|
||||
|
|
Loading…
Reference in New Issue