489 lines
19 KiB
Java
Raw Normal View History

2021-02-01 14:31:20 +09:00
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
2021-02-01 17:47:29 +09:00
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
2021-02-01 14:31:20 +09:00
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
2019-06-06 15:21:15 +03:00
package org.deeplearning4j.models.sequencevectors;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.BaseDL4JTest;
2019-06-06 15:21:15 +03:00
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
import org.deeplearning4j.models.sequencevectors.graph.enums.NoEdgeHandling;
import org.deeplearning4j.models.sequencevectors.graph.enums.PopularityMode;
import org.deeplearning4j.models.sequencevectors.graph.enums.SpreadSpectrum;
import org.deeplearning4j.models.sequencevectors.graph.enums.WalkDirection;
import org.deeplearning4j.models.sequencevectors.graph.primitives.Graph;
import org.deeplearning4j.models.sequencevectors.graph.primitives.Vertex;
import org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker;
import org.deeplearning4j.models.sequencevectors.graph.walkers.impl.PopularityWalker;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceElementFactory;
import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.sequencevectors.serialization.AbstractElementFactory;
import org.deeplearning4j.models.sequencevectors.transformers.impl.GraphTransformer;
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabConstructor;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
2021-03-16 11:57:24 +09:00
import org.junit.jupiter.api.BeforeEach;
2021-03-20 19:06:24 +09:00
import org.junit.jupiter.api.Tag;
2021-03-16 11:57:24 +09:00
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
2019-06-06 15:21:15 +03:00
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
2021-03-16 11:57:24 +09:00
import static org.junit.jupiter.api.Assertions.*;
2019-06-06 15:21:15 +03:00
//@Ignore
public class SequenceVectorsTest extends BaseDL4JTest {
2019-06-06 15:21:15 +03:00
// SLF4J logger named after this test class; used by the tests below for progress and diagnostic output.
protected static final Logger logger = LoggerFactory.getLogger(SequenceVectorsTest.class);
2021-03-16 11:57:24 +09:00
@BeforeEach
2019-06-06 15:21:15 +03:00
public void setUp() throws Exception {
}
@Test
@Tag("long-running")
2019-06-06 15:21:15 +03:00
public void testAbstractW2VModel() throws Exception {
ClassPathResource resource = new ClassPathResource("big/raw_sentences.txt");
File file = resource.getFile();
logger.info("dtype: {}", Nd4j.dataType());
AbstractCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
/*
First we build line iterator
*/
BasicLineIterator underlyingIterator = new BasicLineIterator(file);
/*
Now we need the way to convert lines into Sequences of VocabWords.
In this example that's SentenceTransformer
*/
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
SentenceTransformer transformer =
new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build();
/*
And we pack that transformer into AbstractSequenceIterator
*/
AbstractSequenceIterator<VocabWord> sequenceIterator =
new AbstractSequenceIterator.Builder<>(transformer).build();
/*
Now we should build vocabulary out of sequence iterator.
We can skip this phase, and just set SequenceVectors.resetModel(TRUE), and vocabulary will be mastered internally
*/
VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>()
.addSource(sequenceIterator, 5).setTargetVocabCache(vocabCache).build();
constructor.buildJointVocabulary(false, true);
assertEquals(242, vocabCache.numWords());
assertEquals(634303, vocabCache.totalWordOccurrences());
VocabWord wordz = vocabCache.wordFor("day");
logger.info("Wordz: " + wordz);
/*
Time to build WeightLookupTable instance for our new model
*/
WeightLookupTable<VocabWord> lookupTable = new InMemoryLookupTable.Builder<VocabWord>().lr(0.025)
.vectorLength(150).useAdaGrad(false).cache(vocabCache).build();
/*
reset model is viable only if you're setting SequenceVectors.resetModel() to false
if set to True - it will be called internally
*/
lookupTable.resetWeights(true);
/*
Now we can build SequenceVectors model, that suits our needs
*/
SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration())
// minimum number of occurencies for each element in training corpus. All elements below this value will be ignored
// Please note: this value has effect only if resetModel() set to TRUE, for internal model building. Otherwise it'll be ignored, and actual vocabulary content will be used
.minWordFrequency(5)
// WeightLookupTable
.lookupTable(lookupTable)
// abstract iterator that covers training corpus
.iterate(sequenceIterator)
// vocabulary built prior to modelling
.vocabCache(vocabCache)
// we might want to set layer size here. otherwise it'll be derived from lookupTable
//.layerSize(150)
// batchSize is the number of sequences being processed by 1 thread at once
// this value actually matters if you have iterations > 1
.batchSize(250)
// number of iterations over batch
.iterations(1)
// number of iterations over whole training corpus
.epochs(1)
// if set to true, vocabulary will be built from scratches internally
// otherwise externally provided vocab will be used
.resetModel(false)
/*
These two methods define our training goals. At least one goal should be set to TRUE.
*/
.trainElementsRepresentation(true).trainSequencesRepresentation(false)
.build();
/*
Now, after all options are set, we just call fit()
*/
logger.info("Starting training...");
vectors.fit();
logger.info("Model saved...");
/*
As soon as fit() exits, model considered built, and we can test it.
Please note: all similarity context is handled via SequenceElement's labels, so if you're using SequenceVectors to build models for complex
objects/relations please take care of Labels uniqueness and meaning for yourself.
*/
double sim = vectors.similarity("day", "night");
logger.info("Day/night similarity: " + sim);
assertTrue(sim > 0.6d);
Collection<String> labels = vectors.wordsNearest("day", 10);
logger.info("Nearest labels to 'day': " + labels);
SequenceElementFactory<VocabWord> factory = new AbstractElementFactory<VocabWord>(VocabWord.class);
WordVectorSerializer.writeSequenceVectors(vectors, factory, "seqvec.mod");
SequenceVectors<VocabWord> model = WordVectorSerializer.readSequenceVectors(factory, new File("seqvec.mod"));
sim = model.similarity("day", "night");
logger.info("day/night similarity: " + sim);
}
@Test
@Tag("long-running")
2019-06-06 15:21:15 +03:00
public void testInternalVocabConstruction() throws Exception {
ClassPathResource resource = new ClassPathResource("big/raw_sentences.txt");
File file = resource.getFile();
BasicLineIterator underlyingIterator = new BasicLineIterator(file);
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
SentenceTransformer transformer =
new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build();
AbstractSequenceIterator<VocabWord> sequenceIterator =
new AbstractSequenceIterator.Builder<>(transformer).build();
SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration())
.minWordFrequency(5).iterate(sequenceIterator).batchSize(250).iterations(1).epochs(1)
.resetModel(false).trainElementsRepresentation(true).build();
logger.info("Fitting model...");
vectors.fit();
logger.info("Model ready...");
double sim = vectors.similarity("day", "night");
logger.info("Day/night similarity: " + sim);
assertTrue(sim > 0.6d);
Collection<String> labels = vectors.wordsNearest("day", 10);
logger.info("Nearest labels to 'day': " + labels);
}
/**
 * Builds a {@link SequenceVectors} instance whose elements learning algorithm is given by
 * class name (SkipGram). The test only exercises {@code build()}; nothing is trained.
 *
 * @throws Exception if building the model fails
 */
@Test
public void testElementsLearningAlgo1() throws Exception {
    SequenceVectors.Builder<VocabWord> builder =
            new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration());

    SequenceVectors<VocabWord> model = builder
            .minWordFrequency(5)
            .batchSize(250)
            .iterations(1)
            .elementsLearningAlgorithm("org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram")
            .epochs(1)
            .resetModel(false)
            .trainElementsRepresentation(true)
            .build();
}
/**
 * Builds a {@link SequenceVectors} instance whose sequence learning algorithm is given by
 * class name (DBOW). The test only exercises {@code build()}; nothing is trained.
 *
 * @throws Exception if building the model fails
 */
@Test
public void testSequenceLearningAlgo1() throws Exception {
    SequenceVectors.Builder<VocabWord> builder =
            new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration());

    SequenceVectors<VocabWord> model = builder
            .minWordFrequency(5)
            .batchSize(250)
            .iterations(1)
            .sequenceLearningAlgorithm("org.deeplearning4j.models.embeddings.learning.impl.sequence.DBOW")
            .epochs(1)
            .resetModel(false)
            .trainElementsRepresentation(false)
            .build();
}
@Test
//@Ignore
2019-06-06 15:21:15 +03:00
public void testDeepWalk() throws Exception {
Heartbeat.getInstance().disableHeartbeat();
AbstractCache<Blogger> vocabCache = new AbstractCache.Builder<Blogger>().build();
Graph<Blogger, Double> graph = buildGraph();
GraphWalker<Blogger> walker = new PopularityWalker.Builder<>(graph)
.setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED).setWalkLength(40)
.setWalkDirection(WalkDirection.FORWARD_UNIQUE).setRestartProbability(0.05)
.setPopularitySpread(10).setPopularityMode(PopularityMode.MAXIMUM)
.setSpreadSpectrum(SpreadSpectrum.PROPORTIONAL).build();
/*
GraphWalker<Blogger> walker = new RandomWalker.Builder<Blogger>(graph)
.setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED)
.setWalkLength(40)
.setWalkDirection(WalkDirection.RANDOM)
.setRestartProbability(0.05)
.build();
*/
GraphTransformer<Blogger> graphTransformer = new GraphTransformer.Builder<>(graph).setGraphWalker(walker)
.shuffleOnReset(true).setVocabCache(vocabCache).build();
Blogger blogger = graph.getVertex(0).getValue();
assertEquals(119, blogger.getElementFrequency(), 0.001);
logger.info("Blogger: " + blogger);
AbstractSequenceIterator<Blogger> sequenceIterator =
new AbstractSequenceIterator.Builder<>(graphTransformer).build();
WeightLookupTable<Blogger> lookupTable = new InMemoryLookupTable.Builder<Blogger>().lr(0.025).vectorLength(150)
.useAdaGrad(false).cache(vocabCache).seed(42).build();
lookupTable.resetWeights(true);
SequenceVectors<Blogger> vectors = new SequenceVectors.Builder<Blogger>(new VectorsConfiguration())
// WeightLookupTable
.lookupTable(lookupTable)
// abstract iterator that covers training corpus
.iterate(sequenceIterator)
// vocabulary built prior to modelling
.vocabCache(vocabCache)
// batchSize is the number of sequences being processed by 1 thread at once
// this value actually matters if you have iterations > 1
.batchSize(1000)
// number of iterations over batch
.iterations(1)
// number of iterations over whole training corpus
.epochs(10)
// if set to true, vocabulary will be built from scratches internally
// otherwise externally provided vocab will be used
.resetModel(false)
/*
These two methods define our training goals. At least one goal should be set to TRUE.
*/
.trainElementsRepresentation(true).trainSequencesRepresentation(false)
/*
Specifies elements learning algorithms. SkipGram, for example.
*/
.elementsLearningAlgorithm(new SkipGram<Blogger>())
.learningRate(0.025)
.layerSize(150)
.sampling(0)
.negativeSample(0)
.windowSize(4)
.workers(6)
.seed(42)
.build();
vectors.fit();
vectors.setModelUtils(new FlatModelUtils());
// logger.info("12: " + Arrays.toString(vectors.getWordVector("12")));
double sim = vectors.similarity("12", "72");
Collection<String> list = vectors.wordsNearest("12", 20);
logger.info("12->72: " + sim);
printWords("12", list, vectors);
assertTrue(sim > 0.10);
assertFalse(Double.isNaN(sim));
}
/**
 * Collects the {@link Blogger} value of every vertex returned by
 * {@code graph.getVertices(0, graph.numVertices() - 1)}, in the order the graph returns them.
 *
 * @param graph graph to read the vertices from
 * @return a new list holding the vertex values
 */
private List<Blogger> getBloggersFromGraph(Graph<Blogger, Double> graph) {
    List<Blogger> values = new ArrayList<>();
    int lastIndex = graph.numVertices() - 1;
    graph.getVertices(0, lastIndex).forEach(vertex -> values.add(vertex.getValue()));
    return values;
}
/**
 * Loads the BlogCatalog graph from {@code /ext/Temp/BlogCatalog/nodes.csv} and
 * {@code /ext/Temp/BlogCatalog/edges.csv}.
 *
 * <p>One {@link Blogger} is created per row of the nodes file (first column is the id). Each
 * row of the edges file holds two ids; they are shifted by -1 before being passed to
 * {@code graph.addEdge}. Both CSV readers are closed even when reading fails.
 *
 * <p>The dataset is not shipped with the sources, so the calling test is aborted (reported as
 * skipped by JUnit) through an assumption when either file is missing, instead of failing on
 * machines that do not have it.
 *
 * @return the populated graph
 * @throws IOException if one of the CSV files cannot be read
 * @throws InterruptedException if reader initialisation is interrupted
 */
private static Graph<Blogger, Double> buildGraph() throws IOException, InterruptedException {
    File nodes = new File("/ext/Temp/BlogCatalog/nodes.csv");
    File edges = new File("/ext/Temp/BlogCatalog/edges.csv");

    // Fully qualified so that this helper needs no additional import.
    org.junit.jupiter.api.Assumptions.assumeTrue(nodes.isFile() && edges.isFile(),
            "BlogCatalog dataset not found under /ext/Temp/BlogCatalog");

    List<Blogger> bloggers = new ArrayList<>();
    try (CSVRecordReader reader = new CSVRecordReader(0, ',')) {
        reader.initialize(new FileSplit(nodes));
        while (reader.hasNext()) {
            List<Writable> line = reader.next();
            bloggers.add(new Blogger(line.get(0).toInt()));
        }
    }

    Graph<Blogger, Double> graph = new Graph<>(bloggers, true);

    // load edges
    try (CSVRecordReader reader = new CSVRecordReader(0, ',')) {
        reader.initialize(new FileSplit(edges));
        while (reader.hasNext()) {
            List<Writable> line = reader.next();
            int from = line.get(0).toInt();
            int to = line.get(1).toInt();

            graph.addEdge(from - 1, to - 1, 1.0, false);
        }
    }

    logger.info("Connected on 0: [{}]", graph.getConnectedVertices(0).size());
    logger.info("Connected on 1: [{}]", graph.getConnectedVertices(1).size());
    logger.info("Connected on 3: [{}]", graph.getConnectedVertices(3).size());

    assertEquals(119, graph.getConnectedVertices(0).size());
    assertEquals(9, graph.getConnectedVertices(1).size());
    assertEquals(6, graph.getConnectedVertices(3).size());

    return graph;
}
@Data
private static class Blogger extends SequenceElement {
@Getter
@Setter
private int id;
public Blogger() {
super();
}
public Blogger(int id) {
super();
this.id = id;
}
/**
* This method should return string representation of this SequenceElement, so it can be used for
*
* @return
*/
@Override
public String getLabel() {
return String.valueOf(id);
}
/**
* @return
*/
@Override
public String toJSON() {
return null;
}
@Override
public String toString() {
return "VocabWord{" + "wordFrequency=" + this.elementFrequency + ", index=" + index + ", codes=" + codes
2022-10-21 15:19:32 +02:00
+ ", word='" + id + '\'' + ", points=" + points + ", codeLength="
2019-06-06 15:21:15 +03:00
+ codeLength + '}';
}
}
/**
 * Logs every word in {@code list} together with its similarity to {@code target}, as a single
 * line of {@code 'word': [similarity], } entries.
 *
 * @param target label the similarities are computed against
 * @param list   labels to report
 * @param vec    model used to compute the similarities
 */
private static void printWords(String target, Collection<String> list, SequenceVectors<?> vec) {
    StringBuilder line = new StringBuilder();
    for (String word : list) {
        double sim = vec.similarity(target, word);
        line.append('\'').append(word).append("': [").append(sim).append("], ");
    }
    logger.info("Words close to [{}]: {}", target, line);
}
}