From bc02d9f0b0f5d94d92428d719c99871452408c8e Mon Sep 17 00:00:00 2001
From: Benjamin Possolo
Date: Sun, 26 Apr 2020 09:38:32 -0700
Subject: [PATCH] fix: TfidfVectorizer.vectorize() NPE when fitted from
 LabelAwareIterator

issue #8886

Signed-off-by: Benjamin Possolo
---
 Review notes: generic type arguments are spelled out, the wrapper's
 shutdown() forwards to the delegate, and the new test carries @Test so
 JUnit runs it. Diffstat and hunk counts reflect those changes; the blob
 ids on the "index" lines were not recomputed.

 .../vectorizer/TfidfVectorizer.java           |  3 +-
 .../LabelAwareIteratorWrapper.java            | 69 +++++++++++++++++++
 .../vectorizer/TfidfVectorizerTest.java       | 43 ++++++++++++
 3 files changed, 114 insertions(+), 1 deletion(-)
 create mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/documentiterator/LabelAwareIteratorWrapper.java

diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizer.java
index 0b07ef3b9..92fb83755 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizer.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizer.java
@@ -24,6 +24,7 @@ import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
 import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
 import org.deeplearning4j.text.documentiterator.DocumentIterator;
 import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
+import org.deeplearning4j.text.documentiterator.LabelAwareIteratorWrapper;
 import org.deeplearning4j.text.documentiterator.LabelsSource;
 import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
 import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
@@ -186,7 +187,7 @@ public class TfidfVectorizer extends BaseTextVectorizer {
         }
 
         public Builder setIterator(@NonNull LabelAwareIterator iterator) {
-            this.iterator = iterator;
+            this.iterator = new LabelAwareIteratorWrapper(iterator, labelsSource);
             return this;
         }
 
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/documentiterator/LabelAwareIteratorWrapper.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/documentiterator/LabelAwareIteratorWrapper.java
new file mode 100644
index 000000000..22db35a94
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/documentiterator/LabelAwareIteratorWrapper.java
@@ -0,0 +1,69 @@
+package org.deeplearning4j.text.documentiterator;
+
+import java.util.List;
+
+/**
+ * LabelAwareIterator wrapper which populates a LabelsSource while iterating.
+ * Every label of every document returned by {@link #nextDocument()} is stored
+ * in the sink, which is also what {@link #getLabelsSource()} returns.
+ *
+ * @author Benjamin Possolo
+ */
+public class LabelAwareIteratorWrapper implements LabelAwareIterator {
+
+    private final LabelAwareIterator delegate;
+    private final LabelsSource sink;
+
+    /**
+     * @param delegate iterator that supplies the documents
+     * @param sink labels source that receives the labels seen while iterating
+     */
+    public LabelAwareIteratorWrapper(LabelAwareIterator delegate, LabelsSource sink) {
+        this.delegate = delegate;
+        this.sink = sink;
+    }
+
+    @Override
+    public boolean hasNext() {
+        return delegate.hasNext();
+    }
+
+    @Override
+    public boolean hasNextDocument() {
+        return delegate.hasNextDocument();
+    }
+
+    @Override
+    public LabelsSource getLabelsSource() {
+        return sink;
+    }
+
+    @Override
+    public LabelledDocument next() {
+        return nextDocument();
+    }
+
+    @Override
+    public LabelledDocument nextDocument() {
+        LabelledDocument doc = delegate.nextDocument();
+        List<String> labels = doc.getLabels();
+        if (labels != null) {
+            for (String label : labels) {
+                sink.storeLabel(label);
+            }
+        }
+        return doc;
+    }
+
+    @Override
+    public void reset() {
+        delegate.reset();
+        sink.reset();
+    }
+
+    @Override
+    public void shutdown() {
+        // Forward the call so the wrapped iterator is shut down as well.
+        delegate.shutdown();
+    }
+}
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizerTest.java
index 37b07d412..23a14b305 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizerTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizerTest.java
@@ -24,6 +24,10 @@ import org.junit.rules.TemporaryFolder;
 import org.nd4j.linalg.io.ClassPathResource;
 import org.deeplearning4j.models.word2vec.VocabWord;
 import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
+import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
+import org.deeplearning4j.text.documentiterator.LabelledDocument;
+import org.deeplearning4j.text.documentiterator.LabelsSource;
+import org.deeplearning4j.text.documentiterator.SimpleLabelAwareIterator;
 import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
 import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator;
 import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
@@ -40,6 +44,7 @@ import org.nd4j.linalg.util.SerializationUtils;
 import java.io.File;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.atomic.AtomicLong;
 
 import static org.junit.Assert.*;
@@ -131,6 +136,44 @@ public class TfidfVectorizerTest extends BaseDL4JTest {
         assertEquals(vector, dataSet.getFeatures());
     }
 
+    @Test(timeout = 10000L)
+    public void testTfIdfVectorizerFromLabelAwareIterator() throws Exception {
+        LabelledDocument doc1 = new LabelledDocument();
+        doc1.addLabel("dog");
+        doc1.setContent("it barks like a dog");
+
+        LabelledDocument doc2 = new LabelledDocument();
+        doc2.addLabel("cat");
+        doc2.setContent("it meows like a cat");
+
+        List<LabelledDocument> docs = new ArrayList<>(2);
+        docs.add(doc1);
+        docs.add(doc2);
+
+        LabelAwareIterator iterator = new SimpleLabelAwareIterator(docs);
+        TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
+
+        TfidfVectorizer vectorizer = new TfidfVectorizer
+            .Builder()
+            .setMinWordFrequency(1)
+            .setStopWords(new ArrayList<String>())
+            .setTokenizerFactory(tokenizerFactory)
+            .setIterator(iterator)
+            .allowParallelTokenization(false)
+            .build();
+
+        vectorizer.fit();
+
+        DataSet dataset = vectorizer.vectorize("it meows like a cat", "cat");
+        assertNotNull(dataset);
+
+        LabelsSource source = vectorizer.getLabelsSource();
+        assertEquals(2, source.getNumberOfLabelsUsed());
+        List<String> labels = source.getLabels();
+        assertEquals("dog", labels.get(0));
+        assertEquals("cat", labels.get(1));
+    }
+
     @Test(timeout = 10000L)
     public void testParallelFlag1() throws Exception {
         val vectorizer = new TfidfVectorizer.Builder()