Merge remote-tracking branch 'eclipse/master'
commit
dad6bc5ed2
|
@ -24,6 +24,7 @@ import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||||
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
|
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
|
||||||
import org.deeplearning4j.text.documentiterator.DocumentIterator;
|
import org.deeplearning4j.text.documentiterator.DocumentIterator;
|
||||||
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
|
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
|
||||||
|
import org.deeplearning4j.text.documentiterator.LabelAwareIteratorWrapper;
|
||||||
import org.deeplearning4j.text.documentiterator.LabelsSource;
|
import org.deeplearning4j.text.documentiterator.LabelsSource;
|
||||||
import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
|
import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
|
||||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||||
|
@ -186,7 +187,7 @@ public class TfidfVectorizer extends BaseTextVectorizer {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Builder setIterator(@NonNull LabelAwareIterator iterator) {
|
public Builder setIterator(@NonNull LabelAwareIterator iterator) {
|
||||||
this.iterator = iterator;
|
this.iterator = new LabelAwareIteratorWrapper(iterator, labelsSource);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
package org.deeplearning4j.text.documentiterator;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LabelAwareIterator wrapper which populates a LabelsSource while iterating.
|
||||||
|
*
|
||||||
|
* @author Benjamin Possolo
|
||||||
|
*/
|
||||||
|
public class LabelAwareIteratorWrapper implements LabelAwareIterator {
|
||||||
|
|
||||||
|
private final LabelAwareIterator delegate;
|
||||||
|
private final LabelsSource sink;
|
||||||
|
|
||||||
|
public LabelAwareIteratorWrapper(LabelAwareIterator delegate, LabelsSource sink) {
|
||||||
|
this.delegate = delegate;
|
||||||
|
this.sink = sink;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNext() {
|
||||||
|
return delegate.hasNext();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNextDocument() {
|
||||||
|
return delegate.hasNextDocument();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public LabelsSource getLabelsSource() {
|
||||||
|
return sink;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public LabelledDocument next() {
|
||||||
|
return nextDocument();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public LabelledDocument nextDocument() {
|
||||||
|
LabelledDocument doc = delegate.nextDocument();
|
||||||
|
List<String> labels = doc.getLabels();
|
||||||
|
if (labels != null) {
|
||||||
|
for (String label : labels) {
|
||||||
|
sink.storeLabel(label);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return doc;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
delegate.reset();
|
||||||
|
sink.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void shutdown() {}
|
||||||
|
}
|
|
@ -24,6 +24,10 @@ import org.junit.rules.TemporaryFolder;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||||
|
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
|
||||||
|
import org.deeplearning4j.text.documentiterator.LabelledDocument;
|
||||||
|
import org.deeplearning4j.text.documentiterator.LabelsSource;
|
||||||
|
import org.deeplearning4j.text.documentiterator.SimpleLabelAwareIterator;
|
||||||
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
|
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
|
||||||
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator;
|
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator;
|
||||||
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
|
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
|
||||||
|
@ -40,6 +44,7 @@ import org.nd4j.linalg.util.SerializationUtils;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
@ -131,6 +136,43 @@ public class TfidfVectorizerTest extends BaseDL4JTest {
|
||||||
assertEquals(vector, dataSet.getFeatures());
|
assertEquals(vector, dataSet.getFeatures());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testTfIdfVectorizerFromLabelAwareIterator() throws Exception {
|
||||||
|
LabelledDocument doc1 = new LabelledDocument();
|
||||||
|
doc1.addLabel("dog");
|
||||||
|
doc1.setContent("it barks like a dog");
|
||||||
|
|
||||||
|
LabelledDocument doc2 = new LabelledDocument();
|
||||||
|
doc2.addLabel("cat");
|
||||||
|
doc2.setContent("it meows like a cat");
|
||||||
|
|
||||||
|
List<LabelledDocument> docs = new ArrayList<>(2);
|
||||||
|
docs.add(doc1);
|
||||||
|
docs.add(doc2);
|
||||||
|
|
||||||
|
LabelAwareIterator iterator = new SimpleLabelAwareIterator(docs);
|
||||||
|
TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
|
||||||
|
|
||||||
|
TfidfVectorizer vectorizer = new TfidfVectorizer
|
||||||
|
.Builder()
|
||||||
|
.setMinWordFrequency(1)
|
||||||
|
.setStopWords(new ArrayList<String>())
|
||||||
|
.setTokenizerFactory(tokenizerFactory)
|
||||||
|
.setIterator(iterator)
|
||||||
|
.allowParallelTokenization(false)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
vectorizer.fit();
|
||||||
|
|
||||||
|
DataSet dataset = vectorizer.vectorize("it meows like a cat", "cat");
|
||||||
|
assertNotNull(dataset);
|
||||||
|
|
||||||
|
LabelsSource source = vectorizer.getLabelsSource();
|
||||||
|
assertEquals(2, source.getNumberOfLabelsUsed());
|
||||||
|
List<String> labels = source.getLabels();
|
||||||
|
assertEquals("dog", labels.get(0));
|
||||||
|
assertEquals("cat", labels.get(1));
|
||||||
|
}
|
||||||
|
|
||||||
@Test(timeout = 10000L)
|
@Test(timeout = 10000L)
|
||||||
public void testParallelFlag1() throws Exception {
|
public void testParallelFlag1() throws Exception {
|
||||||
val vectorizer = new TfidfVectorizer.Builder()
|
val vectorizer = new TfidfVectorizer.Builder()
|
||||||
|
|
Loading…
Reference in New Issue