diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/ElementsLearningAlgorithm.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/ElementsLearningAlgorithm.java index b684a40a3..f856fc6ea 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/ElementsLearningAlgorithm.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/ElementsLearningAlgorithm.java @@ -27,7 +27,7 @@ import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import java.util.concurrent.atomic.AtomicLong; /** - * Implementations of this interface should contain element-related learning algorithms. Like skip-gram, cbow or glove + * Implementations of this interface should contain element-related learning algorithms. Like skip-gram or cbow * * @author raver119@gmail.com */ diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/GloVe.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/GloVe.java deleted file mode 100644 index a655cfbde..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/GloVe.java +++ /dev/null @@ -1,427 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.embeddings.learning.impl.elements; - -import lombok.NonNull; -import org.deeplearning4j.models.embeddings.WeightLookupTable; -import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; -import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm; -import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration; -import org.deeplearning4j.models.glove.AbstractCoOccurrences; -import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator; -import org.deeplearning4j.models.sequencevectors.sequence.Sequence; -import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.learning.legacy.AdaGrad; -import org.nd4j.common.primitives.Counter; -import org.nd4j.common.primitives.Pair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; - -/** - * GloVe LearningAlgorithm implementation for SequenceVectors - * - * - * @author raver119@gmail.com - */ -public class GloVe implements ElementsLearningAlgorithm { - - private VocabCache vocabCache; - private AbstractCoOccurrences coOccurrences; - private WeightLookupTable lookupTable; - private VectorsConfiguration configuration; - - private AtomicBoolean isTerminate = new AtomicBoolean(false); - - private INDArray syn0; - - private double xMax; - private boolean shuffle; - private boolean symmetric; - protected double alpha = 0.75d; - protected double learningRate = 0.0d; - protected int maxmemory = 0; - protected int batchSize = 1000; - - private AdaGrad weightAdaGrad; - private AdaGrad biasAdaGrad; - private INDArray bias; - - private int workers = Runtime.getRuntime().availableProcessors(); - - private int vectorLength; - - private static final Logger log = LoggerFactory.getLogger(GloVe.class); - - @Override - public String getCodeName() { - return "GloVe"; - } - - @Override - public void finish() { - log.info("GloVe finalizer..."); - } - - @Override - public void configure(@NonNull VocabCache vocabCache, @NonNull WeightLookupTable lookupTable, - @NonNull VectorsConfiguration configuration) { - this.vocabCache = vocabCache; - this.lookupTable = lookupTable; - this.configuration = configuration; - - this.syn0 = ((InMemoryLookupTable) lookupTable).getSyn0(); - - - this.vectorLength = configuration.getLayersSize(); - - if (this.learningRate == 0.0d) - this.learningRate = configuration.getLearningRate(); - - - - weightAdaGrad = new AdaGrad(new long[] {this.vocabCache.numWords() + 1, vectorLength}, learningRate); - bias = Nd4j.create(syn0.rows()); - - biasAdaGrad = new AdaGrad(bias.shape(), this.learningRate); - - // maxmemory = Runtime.getRuntime().maxMemory() - (vocabCache.numWords() * vectorLength * 2 * 8); - - log.info("GloVe params: {Max Memory: [" + maxmemory + "], Learning rate: [" + this.learningRate + "], Alpha: [" - + alpha + "], xMax: [" + xMax + "], Symmetric: [" + symmetric + "], Shuffle: [" + shuffle - + "]}"); - } - - /** - * pretrain is used to build CoOccurrence matrix for GloVe algorithm - * @param iterator - */ - @Override - public void pretrain(@NonNull SequenceIterator iterator) { - // CoOccurence table should be built here - coOccurrences = new AbstractCoOccurrences.Builder() - // TODO: symmetric should be handled via VectorsConfiguration - .symmetric(this.symmetric).windowSize(configuration.getWindow()).iterate(iterator) - .workers(workers).vocabCache(vocabCache).maxMemory(maxmemory).build(); - - coOccurrences.fit(); - } - - public double learnSequence(Sequence sequence, AtomicLong nextRandom, double learningRate, - BatchSequences batchSequences) { - throw new UnsupportedOperationException(); - } - /** - * Learns sequence using GloVe algorithm - * - * @param sequence - * @param nextRandom - * @param learningRate - */ - @Override - public synchronized double learnSequence(@NonNull Sequence sequence, @NonNull AtomicLong nextRandom, - double learningRate) { - /* - GloVe learning algorithm is implemented like a hack over settled ElementsLearningAlgorithm mechanics. It's called in SequenceVectors context, but actually only for the first call. - All subsequent calls will met early termination condition, and will be successfully ignored. But since elements vectors will be updated within first call, - this will allow compatibility with everything beyond this implementaton - */ - if (isTerminate.get()) - return 0; - - final AtomicLong pairsCount = new AtomicLong(0); - final Counter errorCounter = new Counter<>(); - - //List> coList = coOccurrences.coOccurrenceList(); - - for (int i = 0; i < configuration.getEpochs(); i++) { - - // TODO: shuffle should be built in another way. - //if (shuffle) - //Collections.shuffle(coList); - - Iterator, Double>> pairs = coOccurrences.iterator(); - - List threads = new ArrayList<>(); - for (int x = 0; x < workers; x++) { - threads.add(x, new GloveCalculationsThread(i, x, pairs, pairsCount, errorCounter)); - threads.get(x).start(); - } - - - - for (int x = 0; x < workers; x++) { - try { - threads.get(x).join(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - log.info("Processed [" + pairsCount.get() + "] pairs, Error was [" + errorCounter.getCount(i) + "]"); - } - - isTerminate.set(true); - return 0; - } - - /** - * Since GloVe is learning representations using elements CoOccurences, all training is done in GloVe class internally, so only first thread will execute learning process, - * and the rest of parent threads will just exit learning process - * - * @return True, if training should stop, False otherwise. - */ - @Override - public synchronized boolean isEarlyTerminationHit() { - return isTerminate.get(); - } - - private double iterateSample(T element1, T element2, double score) { - //prediction: input + bias - if (element1.getIndex() < 0 || element1.getIndex() >= syn0.rows()) - throw new IllegalArgumentException("Illegal index for word " + element1.getLabel()); - if (element2.getIndex() < 0 || element2.getIndex() >= syn0.rows()) - throw new IllegalArgumentException("Illegal index for word " + element2.getLabel()); - - INDArray w1Vector = syn0.slice(element1.getIndex()); - INDArray w2Vector = syn0.slice(element2.getIndex()); - - - //w1 * w2 + bias - double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector); - prediction += bias.getDouble(element1.getIndex()) + bias.getDouble(element2.getIndex()) - Math.log(score); - - double fDiff = (score > xMax) ? prediction : Math.pow(score / xMax, alpha) * prediction; // Math.pow(Math.min(1.0,(score / maxCount)),xMax); - - // double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score)); - - if (Double.isNaN(fDiff)) - fDiff = Nd4j.EPS_THRESHOLD; - //amount of change - double gradient = fDiff * learningRate; - - //note the update step here: the gradient is - //the gradient of the OPPOSITE word - //for adagrad we will use the index of the word passed in - //for the gradient calculation we will use the context vector - update(element1, w1Vector, w2Vector, gradient); - update(element2, w2Vector, w1Vector, gradient); - return 0.5 * fDiff * prediction; - } - - private void update(T element1, INDArray wordVector, INDArray contextVector, double gradient) { - //gradient for word vectors - INDArray grad1 = contextVector.mul(gradient); - INDArray update = weightAdaGrad.getGradient(grad1, element1.getIndex(), syn0.shape()); - - //update vector - wordVector.subi(update); - - double w1Bias = bias.getDouble(element1.getIndex()); - double biasGradient = biasAdaGrad.getGradient(gradient, element1.getIndex(), bias.shape()); - double update2 = w1Bias - biasGradient; - bias.putScalar(element1.getIndex(), update2); - } - - private class GloveCalculationsThread extends Thread implements Runnable { - private final int threadId; - private final int epochId; - // private final AbstractCoOccurrences coOccurrences; - private final Iterator, Double>> coList; - - private final AtomicLong pairsCounter; - private final Counter errorCounter; - - public GloveCalculationsThread(int epochId, int threadId, @NonNull Iterator, Double>> pairs, - @NonNull AtomicLong pairsCounter, @NonNull Counter errorCounter) { - this.epochId = epochId; - this.threadId = threadId; - // this.coOccurrences = coOccurrences; - - this.pairsCounter = pairsCounter; - this.errorCounter = errorCounter; - - coList = pairs; - - this.setName("GloVe ELA t." + this.threadId); - } - - @Override - public void run() { - // int startPosition = threadId * (coList.size() / workers); - // int stopPosition = (threadId + 1) * (coList.size() / workers); - // log.info("Total size: [" + coList.size() + "], thread start: [" + startPosition + "], thread stop: [" + stopPosition + "]"); - while (coList.hasNext()) { - - // now we fetch pairs into batch - List, Double>> pairs = new ArrayList<>(); - int cnt = 0; - while (coList.hasNext() && cnt < batchSize) { - pairs.add(coList.next()); - cnt++; - } - - if (shuffle) - Collections.shuffle(pairs); - - Iterator, Double>> iterator = pairs.iterator(); - - while (iterator.hasNext()) { - // now for each pair do appropriate training - Pair, Double> pairDoublePair = iterator.next(); - - // That's probably ugly and probably should be improved somehow - - T element1 = pairDoublePair.getFirst().getFirst(); - T element2 = pairDoublePair.getFirst().getSecond(); - double weight = pairDoublePair.getSecond(); //coOccurrences.getCoOccurrenceCount(element1, element2); - if (weight <= 0) { - // log.warn("Skipping pair ("+ element1.getLabel()+", " + element2.getLabel()+")"); - pairsCounter.incrementAndGet(); - continue; - } - - errorCounter.incrementCount(epochId, iterateSample(element1, element2, weight)); - if (pairsCounter.incrementAndGet() % 1000000 == 0) { - log.info("Processed [" + pairsCounter.get() + "] word pairs so far..."); - } - } - - } - } - } - - public static class Builder { - - protected double xMax = 100.0d; - protected double alpha = 0.75d; - protected double learningRate = 0.0d; - - protected boolean shuffle = false; - protected boolean symmetric = false; - protected int maxmemory = 0; - - protected int batchSize = 1000; - - public Builder() { - - } - - /** - * This parameter specifies batch size for each thread. Also, if shuffle == TRUE, this batch will be shuffled before processing. Default value: 1000; - * - * @param batchSize - * @return - */ - public Builder batchSize(int batchSize) { - this.batchSize = batchSize; - return this; - } - - - /** - * Initial learning rate; default 0.05 - * - * @param eta - * @return - */ - public Builder learningRate(double eta) { - this.learningRate = eta; - return this; - } - - /** - * Parameter in exponent of weighting function; default 0.75 - * - * @param alpha - * @return - */ - public Builder alpha(double alpha) { - this.alpha = alpha; - return this; - } - - /** - * This method allows you to specify maximum memory available for CoOccurrence map builder. - * - * Please note: this option can be considered a debugging method. In most cases setting proper -Xmx argument set to JVM is enough to limit this algorithm. - * Please note: this option won't override -Xmx JVM value. - * - * @param gbytes memory limit, in gigabytes - * @return - */ - public Builder maxMemory(int gbytes) { - this.maxmemory = gbytes; - return this; - } - - /** - * Parameter specifying cutoff in weighting function; default 100.0 - * - * @param xMax - * @return - */ - public Builder xMax(double xMax) { - this.xMax = xMax; - return this; - } - - /** - * Parameter specifying, if cooccurrences list should be shuffled between training epochs - * - * @param reallyShuffle - * @return - */ - public Builder shuffle(boolean reallyShuffle) { - this.shuffle = reallyShuffle; - return this; - } - - /** - * Parameters specifying, if cooccurrences list should be build into both directions from any current word. - * - * @param reallySymmetric - * @return - */ - public Builder symmetric(boolean reallySymmetric) { - this.symmetric = reallySymmetric; - return this; - } - - public GloVe build() { - GloVe ret = new GloVe<>(); - ret.symmetric = this.symmetric; - ret.shuffle = this.shuffle; - ret.xMax = this.xMax; - ret.alpha = this.alpha; - ret.learningRate = this.learningRate; - ret.maxmemory = this.maxmemory; - ret.batchSize = this.batchSize; - - return ret; - } - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java index 8f0003728..136143d79 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java @@ -24,6 +24,7 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.io.LineIterator; import org.apache.commons.io.output.CloseShieldOutputStream; +import org.deeplearning4j.common.util.DL4JFileUtils; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; @@ -32,7 +33,6 @@ import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils; import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl; import org.deeplearning4j.models.fasttext.FastText; -import org.deeplearning4j.models.glove.Glove; import org.deeplearning4j.models.paragraphvectors.ParagraphVectors; import org.deeplearning4j.models.sequencevectors.SequenceVectors; import org.deeplearning4j.models.sequencevectors.interfaces.SequenceElementFactory; @@ -50,19 +50,18 @@ import org.deeplearning4j.text.documentiterator.LabelsSource; import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.deeplearning4j.common.util.DL4JFileUtils; +import org.nd4j.common.primitives.Pair; +import org.nd4j.common.util.OneTimeLogger; import org.nd4j.compression.impl.NoOp; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.ops.transforms.Transforms; -import org.nd4j.common.primitives.Pair; import org.nd4j.shade.jackson.databind.DeserializationFeature; import org.nd4j.shade.jackson.databind.MapperFeature; import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.SerializationFeature; import org.nd4j.storage.CompressedRamStorage; -import org.nd4j.common.util.OneTimeLogger; import java.io.*; import java.nio.charset.StandardCharsets; @@ -108,10 +107,6 @@ import java.util.zip.*; * {@link #readParagraphVectors(String)} * {@link #readParagraphVectors(InputStream)} * - *
  • Serializers for GloVe:
  • - * {@link #writeWordVectors(Glove, File)} - * {@link #writeWordVectors(Glove, String)} - * {@link #writeWordVectors(Glove, OutputStream)} * *
  • Adapters
  • * {@link #fromTableAndVocab(WeightLookupTable, VocabCache)} @@ -119,7 +114,6 @@ import java.util.zip.*; * {@link #loadTxt(File)} * *
  • Serializers to tSNE format
  • - * {@link #writeTsneFormat(Glove, INDArray, File)} * {@link #writeTsneFormat(Word2Vec, INDArray, File)} * *
  • FastText serializer:
  • @@ -1114,48 +1108,6 @@ public class WordVectorSerializer { } } - /** - * This method saves GloVe model to the given output stream. - * - * @param vectors GloVe model to be saved - * @param file path where model should be saved to - */ - public static void writeWordVectors(@NonNull Glove vectors, @NonNull File file) { - try (BufferedOutputStream fos = new BufferedOutputStream(new FileOutputStream(file))) { - writeWordVectors(vectors, fos); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - /** - * This method saves GloVe model to the given output stream. - * - * @param vectors GloVe model to be saved - * @param path path where model should be saved to - */ - public static void writeWordVectors(@NonNull Glove vectors, @NonNull String path) { - try (BufferedOutputStream fos = new BufferedOutputStream(new FileOutputStream(path))) { - writeWordVectors(vectors, fos); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - /** - * This method saves GloVe model to the given OutputStream - * - * @param vectors GloVe model to be saved - * @param stream OutputStream where model should be saved to - */ - public static void writeWordVectors(@NonNull Glove vectors, @NonNull OutputStream stream) { - try { - writeWordVectors(vectors.lookupTable(), stream); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - /** * This method saves paragraph vectors to the given output stream. * @@ -1818,43 +1770,6 @@ public class WordVectorSerializer { return fromPair(Pair.makePair((InMemoryLookupTable) lookupTable, (VocabCache) cache)); } - /** - * Write the tsne format - * - * @param vec the word vectors to use for labeling - * @param tsne the tsne array to write - * @param csv the file to use - * @throws Exception - */ - public static void writeTsneFormat(Glove vec, INDArray tsne, File csv) throws Exception { - try (BufferedWriter write = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(csv), StandardCharsets.UTF_8))) { - int words = 0; - InMemoryLookupCache l = (InMemoryLookupCache) vec.vocab(); - for (String word : vec.vocab().words()) { - if (word == null) { - continue; - } - StringBuilder sb = new StringBuilder(); - INDArray wordVector = tsne.getRow(l.wordFor(word).getIndex()); - for (int j = 0; j < wordVector.length(); j++) { - sb.append(wordVector.getDouble(j)); - if (j < wordVector.length() - 1) { - sb.append(","); - } - } - sb.append(","); - sb.append(word.replaceAll(" ", WHITESPACE_REPLACEMENT)); - sb.append(" "); - - sb.append("\n"); - write.write(sb.toString()); - - } - - log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize()); - } - } - /** * Write the tsne format * diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/AbstractCoOccurrences.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/AbstractCoOccurrences.java deleted file mode 100644 index 969dbaeb9..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/AbstractCoOccurrences.java +++ /dev/null @@ -1,652 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove; - -import lombok.NonNull; -import org.deeplearning4j.models.glove.count.*; -import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator; -import org.deeplearning4j.models.sequencevectors.iterators.FilteredSequenceIterator; -import org.deeplearning4j.models.sequencevectors.iterators.SynchronizedSequenceIterator; -import org.deeplearning4j.models.sequencevectors.sequence.Sequence; -import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; -import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator; -import org.deeplearning4j.text.sentenceiterator.SentenceIterator; -import org.deeplearning4j.text.sentenceiterator.SynchronizedSentenceIterator; -import org.deeplearning4j.common.util.DL4JFileUtils; -import org.nd4j.common.util.ThreadUtils; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.locks.ReentrantReadWriteLock; - -/** - * This class implements building cooccurrence map for abstract training corpus. - * However it's performance rather low, due to exsessive IO that happens in ShadowCopyThread - * - * PLEASE NOTE: Current implementation involves massive IO, and it should be rewritter as soon as ND4j gets sparse arrays support - * - * @author raver119@gmail.com - */ -public class AbstractCoOccurrences implements Serializable { - - protected boolean symmetric; - protected int windowSize; - protected VocabCache vocabCache; - protected SequenceIterator sequenceIterator; - - // please note, we need enough room for ShadowCopy thread, that's why -1 there - protected int workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1); - - // target file, where text with cooccurrencies should be saved - protected File targetFile; - - protected ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); - - protected long memory_threshold = 0; - - private ShadowCopyThread shadowThread; - - // private Counter sentenceOccurrences = Util.parallelCounter(); - //private CounterMap coOccurrenceCounts = Util.parallelCounterMap(); - private volatile CountMap coOccurrenceCounts = new CountMap<>(); - //private Counter occurrenceAllocations = Util.parallelCounter(); - //private List> coOccurrences; - private AtomicLong processedSequences = new AtomicLong(0); - - - protected static final Logger logger = LoggerFactory.getLogger(AbstractCoOccurrences.class); - - // this method should be private, to avoid non-configured instantiation - private AbstractCoOccurrences() {} - - /** - * This method returns cooccurrence distance weights for two SequenceElements - * - * @param element1 - * @param element2 - * @return distance weight - */ - public double getCoOccurrenceCount(@NonNull T element1, @NonNull T element2) { - return coOccurrenceCounts.getCount(element1, element2); - } - - /** - * This method returns estimated memory footrpint, based on current CountMap content - * @return - */ - protected long getMemoryFootprint() { - // TODO: implement this method. It should return approx. memory used by appropriate CountMap - try { - lock.readLock().lock(); - return ((long) coOccurrenceCounts.size()) * 24L * 5L; - } finally { - lock.readLock().unlock(); - } - } - - /** - * This memory returns memory threshold, defined as 1/2 of memory allowed for allocation - * @return - */ - protected long getMemoryThreshold() { - return memory_threshold / 2L; - } - - public void fit() { - shadowThread = new ShadowCopyThread(); - shadowThread.start(); - - // we should reset iterator before counting cooccurrences - sequenceIterator.reset(); - - List threads = new ArrayList<>(); - for (int x = 0; x < workers; x++) { - threads.add(x, new CoOccurrencesCalculatorThread(x, new FilteredSequenceIterator<>( - new SynchronizedSequenceIterator<>(sequenceIterator), vocabCache), processedSequences)); - threads.get(x).start(); - } - - for (int x = 0; x < workers; x++) { - try { - threads.get(x).join(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - shadowThread.finish(); - logger.info("CoOccurrences map was built."); - } - - /** - * - * This method returns iterator with elements pairs and their weights. Resulting iterator is safe to use in multi-threaded environment. - * - * Developer's note: thread safety on received iterator is delegated to PrefetchedSentenceIterator - * @return - */ - public Iterator, Double>> iterator() { - final SentenceIterator iterator; - - try { - iterator = new SynchronizedSentenceIterator( - new PrefetchingSentenceIterator.Builder(new BasicLineIterator(targetFile)) - .setFetchSize(500000).build()); - - } catch (Exception e) { - logger.error("Target file was not found on last stage!"); - throw new RuntimeException(e); - } - return new Iterator, Double>>() { - /* - iterator should be built on top of current text file with all pairs - */ - - @Override - public boolean hasNext() { - return iterator.hasNext(); - } - - @Override - public Pair, Double> next() { - String line = iterator.nextSentence(); - String[] strings = line.split(" "); - - T element1 = vocabCache.elementAtIndex(Integer.valueOf(strings[0])); - T element2 = vocabCache.elementAtIndex(Integer.valueOf(strings[1])); - Double weight = Double.valueOf(strings[2]); - - return new Pair<>(new Pair<>(element1, element2), weight); - } - - @Override - public void remove() { - throw new UnsupportedOperationException("remove() method can't be supported on read-only interface"); - } - }; - } - - public static class Builder { - - protected boolean symmetric; - protected int windowSize = 5; - protected VocabCache vocabCache; - protected SequenceIterator sequenceIterator; - protected int workers = Runtime.getRuntime().availableProcessors(); - protected File target; - protected long maxmemory = Runtime.getRuntime().maxMemory(); - - public Builder() { - - } - - public Builder symmetric(boolean reallySymmetric) { - this.symmetric = reallySymmetric; - return this; - } - - public Builder windowSize(int windowSize) { - this.windowSize = windowSize; - return this; - } - - public Builder vocabCache(@NonNull VocabCache cache) { - this.vocabCache = cache; - return this; - } - - public Builder iterate(@NonNull SequenceIterator iterator) { - this.sequenceIterator = new SynchronizedSequenceIterator<>(iterator); - return this; - } - - public Builder workers(int numWorkers) { - this.workers = numWorkers; - return this; - } - - /** - * This method allows you to specify maximum memory available for CoOccurrence map builder. - * - * Please note: this option can be considered a debugging method. In most cases setting proper -Xmx argument set to JVM is enough to limit this algorithm. - * Please note: this option won't override -Xmx JVM value. - * - * @param gbytes memory available, in GigaBytes - * @return - */ - public Builder maxMemory(int gbytes) { - if (gbytes > 0) { - this.maxmemory = Math.max(gbytes - 1, 1) * 1024 * 1024 * 1024L; - } - - return this; - } - - /** - * Path to save cooccurrence map after construction. - * If targetFile is not specified, temporary file will be used. - * - * @param path - * @return - */ - public Builder targetFile(@NonNull String path) { - this.targetFile(new File(path)); - return this; - } - - /** - * Path to save cooccurrence map after construction. - * If targetFile is not specified, temporary file will be used. - * - * @param file - * @return - */ - public Builder targetFile(@NonNull File file) { - this.target = file; - return this; - } - - public AbstractCoOccurrences build() { - AbstractCoOccurrences ret = new AbstractCoOccurrences<>(); - ret.sequenceIterator = this.sequenceIterator; - ret.windowSize = this.windowSize; - ret.vocabCache = this.vocabCache; - ret.symmetric = this.symmetric; - ret.workers = this.workers; - - if (this.maxmemory < 1) { - this.maxmemory = Runtime.getRuntime().maxMemory(); - } - ret.memory_threshold = this.maxmemory; - - - logger.info("Actual memory limit: [" + this.maxmemory + "]"); - - // use temp file, if no target file was specified - try { - if (this.target == null) { - this.target = DL4JFileUtils.createTempFile("cooccurrence", "map"); - } - this.target.deleteOnExit(); - } catch (Exception e) { - throw new RuntimeException(e); - } - - ret.targetFile = this.target; - - return ret; - } - } - - private class CoOccurrencesCalculatorThread extends Thread implements Runnable { - - private final SequenceIterator iterator; - private final AtomicLong sequenceCounter; - private int threadId; - - public CoOccurrencesCalculatorThread(int threadId, @NonNull SequenceIterator iterator, - @NonNull AtomicLong sequenceCounter) { - this.iterator = iterator; - this.sequenceCounter = sequenceCounter; - this.threadId = threadId; - - this.setName("CoOccurrencesCalculatorThread " + threadId); - } - - @Override - public void run() { - while (iterator.hasMoreSequences()) { - Sequence sequence = iterator.nextSequence(); - - List tokens = new ArrayList<>(sequence.asLabels()); - // logger.info("Tokens size: " + tokens.size()); - for (int x = 0; x < sequence.getElements().size(); x++) { - int wordIdx = vocabCache.indexOf(tokens.get(x)); - if (wordIdx < 0) { - continue; - } - String w1 = vocabCache.wordFor(tokens.get(x)).getLabel(); - - // THIS iS SAFE TO REMOVE, NO CHANCE WE'll HAVE UNK WORD INSIDE SEQUENCE - /*if(w1.equals(Glove.UNK)) - continue; - */ - - int windowStop = Math.min(x + windowSize + 1, tokens.size()); - for (int j = x; j < windowStop; j++) { - int otherWord = vocabCache.indexOf(tokens.get(j)); - if (otherWord < 0) { - continue; - } - String w2 = vocabCache.wordFor(tokens.get(j)).getLabel(); - - if (w2.equals(Glove.DEFAULT_UNK) || otherWord == wordIdx) { - continue; - } - - - T tokenX = vocabCache.wordFor(tokens.get(x)); - T tokenJ = vocabCache.wordFor(tokens.get(j)); - double nWeight = 1.0 / (j - x + Nd4j.EPS_THRESHOLD); - - while (getMemoryFootprint() >= getMemoryThreshold()) { - shadowThread.invoke(); - /*lock.readLock().lock(); - int size = coOccurrenceCounts.size(); - lock.readLock().unlock(); - */ - if (threadId == 0) { - logger.debug("Memory consuimption > threshold: {footrpint: [" + getMemoryFootprint() - + "], threshold: [" + getMemoryThreshold() + "] }"); - } - ThreadUtils.uncheckedSleep(10000); - } - /* - if (getMemoryFootprint() == 0) { - logger.info("Zero size!"); - } - */ - - try { - lock.readLock().lock(); - if (wordIdx < otherWord) { - coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight); - if (symmetric) { - coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight); - } - } else { - coOccurrenceCounts.incrementCount(tokenJ, tokenX, nWeight); - - if (symmetric) { - coOccurrenceCounts.incrementCount(tokenX, tokenJ, nWeight); - } - } - } finally { - lock.readLock().unlock(); - } - } - } - - sequenceCounter.incrementAndGet(); - } - } - } - - /** - * This class is designed to provide shadow copy functionality for CoOccurence maps, since with proper corpus size you can't fit such a map into memory - * - */ - private class ShadowCopyThread extends Thread implements Runnable { - - private AtomicBoolean isFinished = new AtomicBoolean(false); - private AtomicBoolean isTerminate = new AtomicBoolean(false); - private AtomicBoolean isInvoked = new AtomicBoolean(false); - private AtomicBoolean shouldInvoke = new AtomicBoolean(false); - - // file that contains resuts from previous runs - private File[] tempFiles; - private RoundCount counter; - - public ShadowCopyThread() { - try { - - counter = new RoundCount(1); - tempFiles = new File[2]; - - tempFiles[0] = DL4JFileUtils.createTempFile("aco", "tmp"); - tempFiles[1] = DL4JFileUtils.createTempFile("aco", "tmp"); - - tempFiles[0].deleteOnExit(); - tempFiles[1].deleteOnExit(); - } catch (Exception e) { - throw new RuntimeException(e); - } - - this.setName("ACO ShadowCopy thread"); - } - - @Override - public void run() { - /* - Basic idea is pretty simple: run quetly, untill memory gets filled up to some high volume. - As soon as this happens - execute shadow copy. - */ - while (!isFinished.get() && !isTerminate.get()) { - // check used memory. if memory use below threshold - sleep for a while. if above threshold - invoke copier - - if (getMemoryFootprint() > getMemoryThreshold() || (shouldInvoke.get() && !isInvoked.get())) { - // we'll just invoke copier, nothing else - shouldInvoke.compareAndSet(true, false); - invokeBlocking(); - } else { - /* - commented and left here for future debugging purposes, if needed - - //lock.readLock().lock(); - //int size = coOccurrenceCounts.size(); - //lock.readLock().unlock(); - //logger.info("Current memory situation: {size: [" +size+ "], footprint: [" + getMemoryFootprint()+"], threshold: ["+ getMemoryThreshold() +"]}"); - */ - ThreadUtils.uncheckedSleep(1000); - } - } - } - - /** - * This methods advises shadow copy process to start - */ - public void invoke() { - shouldInvoke.compareAndSet(false, true); - } - - /** - * This methods dumps cooccurrence map into save file. - * Please note: this method is synchronized and will block, until complete - */ - public synchronized void invokeBlocking() { - if (getMemoryFootprint() < getMemoryThreshold() && !isFinished.get()) { - return; - } - - int numberOfLinesSaved = 0; - - isInvoked.set(true); - - logger.debug("Memory purge started."); - - /* - Basic plan: - 1. Open temp file - 2. Read that file line by line - 3. For each read line do synchronization in memory > new file direction - */ - - counter.tick(); - - CountMap localMap; - try { - // in any given moment there's going to be only 1 WriteLock, due to invokeBlocking() being synchronized call - lock.writeLock().lock(); - - - - // obtain local copy of CountMap - localMap = coOccurrenceCounts; - - // set new CountMap, and release write lock - coOccurrenceCounts = new CountMap<>(); - } finally { - lock.writeLock().unlock(); - } - - try { - - File file = null; - if (!isFinished.get()) { - file = tempFiles[counter.previous()]; - } else - file = targetFile; - - - // PrintWriter pw = new PrintWriter(file); - - int linesRead = 0; - - logger.debug("Saving to: [" + counter.get() + "], Reading from: [" + counter.previous() + "]"); - CoOccurenceReader reader = - new BinaryCoOccurrenceReader<>(tempFiles[counter.previous()], vocabCache, localMap); - CoOccurrenceWriter writer = (isFinished.get()) ? new ASCIICoOccurrenceWriter(targetFile) - : new BinaryCoOccurrenceWriter(tempFiles[counter.get()]); - while (reader.hasMoreObjects()) { - CoOccurrenceWeight line = reader.nextObject(); - - if (line != null) { - writer.writeObject(line); - numberOfLinesSaved++; - linesRead++; - } - } - reader.finish(); - - logger.debug("Lines read: [" + linesRead + "]"); - - //now, we can dump the rest of elements, which were not presented in existing dump - Iterator> iterator = localMap.getPairIterator(); - while (iterator.hasNext()) { - Pair pair = iterator.next(); - double mWeight = localMap.getCount(pair); - CoOccurrenceWeight object = new CoOccurrenceWeight<>(); - object.setElement1(pair.getFirst()); - object.setElement2(pair.getSecond()); - object.setWeight(mWeight); - - writer.writeObject(object); - - numberOfLinesSaved++; - // if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]"); - } - - writer.finish(); - - /* - SentenceIterator sIterator = new PrefetchingSentenceIterator.Builder(new BasicLineIterator(tempFiles[counter.get()])) - .setFetchSize(500000) - .build(); - - - int linesRead = 0; - while (sIterator.hasNext()) { - //List list = new ArrayList<>(reader.next()); - String sentence = sIterator.nextSentence(); - if (sentence == null || sentence.isEmpty()) continue; - String[] strings = sentence.split(" "); - - - // first two elements are integers - vocab indexes - //T element1 = vocabCache.wordFor(vocabCache.wordAtIndex(list.get(0).toInt())); - //T element2 = vocabCache.wordFor(vocabCache.wordAtIndex(list.get(1).toInt())); - T element1 = vocabCache.elementAtIndex(Integer.valueOf(strings[0])); - T element2 = vocabCache.elementAtIndex(Integer.valueOf(strings[1])); - - // getting third element, previously stored weight - double sWeight = Double.valueOf(strings[2]); // list.get(2).toDouble(); - - // now, since we have both elements ready, we can check this pair against inmemory map - double mWeight = localMap.getCount(element1, element2); - if (mWeight <= 0) { - // this means we have no such pair in memory, so we'll do nothing to sWeight - } else { - // since we have new weight value in memory, we should update sWeight value before moving it off memory - sWeight += mWeight; - - // original pair can be safely removed from CountMap - localMap.removePair(element1,element2); - } - - StringBuilder builder = new StringBuilder().append(element1.getIndex()).append(" ").append(element2.getIndex()).append(" ").append(sWeight); - pw.println(builder.toString()); - numberOfLinesSaved++; - linesRead++; - - // if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]"); - // if (linesRead % 100000 == 0) logger.info("Lines read: [" + linesRead +"]"); - } - */ - /* - logger.info("Lines read: [" + linesRead + "]"); - - //now, we can dump the rest of elements, which were not presented in existing dump - Iterator> iterator = localMap.getPairIterator(); - while (iterator.hasNext()) { - Pair pair = iterator.next(); - double mWeight = localMap.getCount(pair); - - StringBuilder builder = new StringBuilder().append(pair.getFirst().getIndex()).append(" ").append(pair.getFirst().getIndex()).append(" ").append(mWeight); - pw.println(builder.toString()); - numberOfLinesSaved++; - - // if (numberOfLinesSaved % 100000 == 0) logger.info("Lines saved: [" + numberOfLinesSaved +"]"); - } - - pw.flush(); - pw.close(); - - */ - - // just a hint for gc - localMap = null; - //sIterator.finish(); - } catch (Exception e) { - throw new RuntimeException(e); - } - - logger.info("Number of word pairs saved so far: [" + numberOfLinesSaved + "]"); - isInvoked.set(false); - } - - /** - * This method provides soft finish ability for shadow copy process. - * Please note: it's blocking call, since it requires for final merge. - */ - public void finish() { - if (this.isFinished.get()) { - return; - } - - this.isFinished.set(true); - invokeBlocking(); - } - - /** - * This method provides hard fiinish ability for shadow copy process - */ - public void terminate() { - this.isTerminate.set(true); - } - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/Glove.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/Glove.java deleted file mode 100644 index 44887e8fe..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/Glove.java +++ /dev/null @@ -1,444 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove; - -import lombok.NonNull; -import org.deeplearning4j.models.embeddings.WeightLookupTable; -import org.deeplearning4j.models.embeddings.learning.impl.elements.GloVe; -import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration; -import org.deeplearning4j.models.embeddings.reader.ModelUtils; -import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; -import org.deeplearning4j.models.sequencevectors.SequenceVectors; -import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator; -import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener; -import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator; -import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer; -import org.deeplearning4j.models.word2vec.VocabWord; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.deeplearning4j.text.documentiterator.DocumentIterator; -import org.deeplearning4j.text.sentenceiterator.SentenceIterator; -import org.deeplearning4j.text.sentenceiterator.StreamLineIterator; -import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; - -import java.util.Collection; -import java.util.List; - -/** - * GlobalVectors standalone implementation for DL4j. - * Based on original Stanford GloVe http://www-nlp.stanford.edu/pubs/glove.pdf - * - * @author raver119@gmail.com - */ -public class Glove extends SequenceVectors { - - protected Glove() { - - } - - public static class Builder extends SequenceVectors.Builder { - private double xMax; - private boolean shuffle; - private boolean symmetric; - protected double alpha = 0.75d; - private int maxmemory = (int) (Runtime.getRuntime().totalMemory() / 1024 / 1024 / 1024); - - protected TokenizerFactory tokenFactory; - protected SentenceIterator sentenceIterator; - protected DocumentIterator documentIterator; - - public Builder() { - super(); - } - - - public Builder(@NonNull VectorsConfiguration configuration) { - super(configuration); - } - - - /** - * This method has no effect for GloVe - * - * @param vec existing WordVectors model - * @return - */ - @Override - public Builder useExistingWordVectors(@NonNull WordVectors vec) { - return this; - } - - @Override - public Builder iterate(@NonNull SequenceIterator iterator) { - super.iterate(iterator); - return this; - } - - /** - * Specifies minibatch size for training process. - * - * @param batchSize - * @return - */ - @Override - public Builder batchSize(int batchSize) { - super.batchSize(batchSize); - return this; - } - - /** - * Ierations and epochs are the same in GloVe implementation. - * - * @param iterations - * @return - */ - @Override - public Builder iterations(int iterations) { - super.epochs(iterations); - return this; - } - - /** - * Sets the number of iteration over training corpus during training - * - * @param numEpochs - * @return - */ - @Override - public Builder epochs(int numEpochs) { - super.epochs(numEpochs); - return this; - } - - @Override - public Builder useAdaGrad(boolean reallyUse) { - super.useAdaGrad(true); - return this; - } - - @Override - public Builder layerSize(int layerSize) { - super.layerSize(layerSize); - return this; - } - - @Override - public Builder learningRate(double learningRate) { - super.learningRate(learningRate); - return this; - } - - /** - * Sets minimum word frequency during vocabulary mastering. - * Please note: this option is ignored, if vocabulary is built outside of GloVe - * - * @param minWordFrequency - * @return - */ - @Override - public Builder minWordFrequency(int minWordFrequency) { - super.minWordFrequency(minWordFrequency); - return this; - } - - @Override - public Builder minLearningRate(double minLearningRate) { - super.minLearningRate(minLearningRate); - return this; - } - - @Override - public Builder resetModel(boolean reallyReset) { - super.resetModel(reallyReset); - return this; - } - - @Override - public Builder vocabCache(@NonNull VocabCache vocabCache) { - super.vocabCache(vocabCache); - return this; - } - - @Override - public Builder lookupTable(@NonNull WeightLookupTable lookupTable) { - super.lookupTable(lookupTable); - return this; - } - - @Override - @Deprecated - public Builder sampling(double sampling) { - super.sampling(sampling); - return this; - } - - @Override - @Deprecated - public Builder negativeSample(double negative) { - super.negativeSample(negative); - return this; - } - - @Override - public Builder stopWords(@NonNull List stopList) { - super.stopWords(stopList); - return this; - } - - @Override - public Builder trainElementsRepresentation(boolean trainElements) { - super.trainElementsRepresentation(true); - return this; - } - - @Override - @Deprecated - public Builder trainSequencesRepresentation(boolean trainSequences) { - super.trainSequencesRepresentation(false); - return this; - } - - @Override - public Builder stopWords(@NonNull Collection stopList) { - super.stopWords(stopList); - return this; - } - - @Override - public Builder windowSize(int windowSize) { - super.windowSize(windowSize); - return this; - } - - @Override - public Builder seed(long randomSeed) { - super.seed(randomSeed); - return this; - } - - @Override - public Builder workers(int numWorkers) { - super.workers(numWorkers); - return this; - } - - /** - * Sets TokenizerFactory to be used for training - * - * @param tokenizerFactory - * @return - */ - public Builder tokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) { - this.tokenFactory = tokenizerFactory; - return this; - } - - /** - * Parameter specifying cutoff in weighting function; default 100.0 - * - * @param xMax - * @return - */ - public Builder xMax(double xMax) { - this.xMax = xMax; - return this; - } - - /** - * Parameters specifying, if cooccurrences list should be build into both directions from any current word. - * - * @param reallySymmetric - * @return - */ - public Builder symmetric(boolean reallySymmetric) { - this.symmetric = reallySymmetric; - return this; - } - - /** - * Parameter specifying, if cooccurrences list should be shuffled between training epochs - * - * @param reallyShuffle - * @return - */ - public Builder shuffle(boolean reallyShuffle) { - this.shuffle = reallyShuffle; - return this; - } - - /** - * This method has no effect for ParagraphVectors - * - * @param windows - * @return - */ - @Override - public Builder useVariableWindow(int... windows) { - // no-op - return this; - } - - /** - * Parameter in exponent of weighting function; default 0.75 - * - * @param alpha - * @return - */ - public Builder alpha(double alpha) { - this.alpha = alpha; - return this; - } - - public Builder iterate(@NonNull SentenceIterator iterator) { - this.sentenceIterator = iterator; - return this; - } - - public Builder iterate(@NonNull DocumentIterator iterator) { - this.sentenceIterator = new StreamLineIterator.Builder(iterator).setFetchSize(100).build(); - return this; - } - - /** - * Sets ModelUtils that gonna be used as provider for utility methods: similarity(), wordsNearest(), accuracy(), etc - * - * @param modelUtils model utils to be used - * @return - */ - @Override - public Builder modelUtils(@NonNull ModelUtils modelUtils) { - super.modelUtils(modelUtils); - return this; - } - - /** - * This method sets VectorsListeners for this SequenceVectors model - * - * @param vectorsListeners - * @return - */ - @Override - public Builder setVectorsListeners(@NonNull Collection> vectorsListeners) { - super.setVectorsListeners(vectorsListeners); - return this; - } - - /** - * This method allows you to specify maximum memory available for CoOccurrence map builder. - * - * Please note: this option can be considered a debugging method. In most cases setting proper -Xmx argument set to JVM is enough to limit this algorithm. - * Please note: this option won't override -Xmx JVM value. - * - * @param gbytes memory limit, in gigabytes - * @return - */ - public Builder maxMemory(int gbytes) { - this.maxmemory = gbytes; - return this; - } - - /** - * This method allows you to specify SequenceElement that will be used as UNK element, if UNK is used - * - * @param element - * @return - */ - @Override - public Builder unknownElement(VocabWord element) { - super.unknownElement(element); - return this; - } - - /** - * This method allows you to specify, if UNK word should be used internally - * - * @param reallyUse - * @return - */ - @Override - public Builder useUnknown(boolean reallyUse) { - super.useUnknown(reallyUse); - if (this.unknownElement == null) { - this.unknownElement(new VocabWord(1.0, Glove.DEFAULT_UNK)); - } - return this; - } - - public Glove build() { - presetTables(); - - Glove ret = new Glove(); - - - // hardcoded value for glove - - if (sentenceIterator != null) { - SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(sentenceIterator) - .tokenizerFactory(tokenFactory).build(); - this.iterator = new AbstractSequenceIterator.Builder<>(transformer).build(); - } - - - ret.trainElementsVectors = true; - ret.trainSequenceVectors = false; - ret.useAdeGrad = true; - this.useAdaGrad = true; - - ret.learningRate.set(this.learningRate); - ret.resetModel = this.resetModel; - ret.batchSize = this.batchSize; - ret.iterator = this.iterator; - ret.numEpochs = this.numEpochs; - ret.numIterations = this.iterations; - ret.layerSize = this.layerSize; - - ret.useUnknown = this.useUnknown; - ret.unknownElement = this.unknownElement; - - - - this.configuration.setLearningRate(this.learningRate); - this.configuration.setLayersSize(layerSize); - this.configuration.setHugeModelExpected(hugeModelExpected); - this.configuration.setWindow(window); - this.configuration.setMinWordFrequency(minWordFrequency); - this.configuration.setIterations(iterations); - this.configuration.setSeed(seed); - this.configuration.setBatchSize(batchSize); - this.configuration.setLearningRateDecayWords(learningRateDecayWords); - this.configuration.setMinLearningRate(minLearningRate); - this.configuration.setSampling(this.sampling); - this.configuration.setUseAdaGrad(useAdaGrad); - this.configuration.setNegative(negative); - this.configuration.setEpochs(this.numEpochs); - - - ret.configuration = this.configuration; - - ret.lookupTable = this.lookupTable; - ret.vocab = this.vocabCache; - ret.modelUtils = this.modelUtils; - ret.eventListeners = this.vectorsListeners; - - - ret.elementsLearningAlgorithm = new GloVe.Builder().learningRate(this.learningRate) - .shuffle(this.shuffle).symmetric(this.symmetric).xMax(this.xMax).alpha(this.alpha) - .maxMemory(maxmemory).build(); - - return ret; - } - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/GloveWeightLookupTable.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/GloveWeightLookupTable.java deleted file mode 100644 index bc52a1422..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/GloveWeightLookupTable.java +++ /dev/null @@ -1,334 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove; - - -import org.apache.commons.io.IOUtils; -import org.apache.commons.io.LineIterator; -import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; -import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; -import org.deeplearning4j.models.word2vec.Word2Vec; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.rng.Random; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.learning.legacy.AdaGrad; - -import java.io.IOException; -import java.io.InputStream; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.atomic.AtomicLong; - -/** - * Glove lookup table - * - * @author Adam Gibson - */ -// Deprecated due to logic being pulled off WeightLookupTable classes into LearningAlgorithm interfaces for better code. -@Deprecated -public class GloveWeightLookupTable extends InMemoryLookupTable { - - - private AdaGrad weightAdaGrad; - private AdaGrad biasAdaGrad; - private INDArray bias; - //also known as alpha - private double xMax = 0.75; - private double maxCount = 100; - - - public GloveWeightLookupTable(VocabCache vocab, int vectorLength, boolean useAdaGrad, double lr, Random gen, - double negative, double xMax, double maxCount) { - super(vocab, vectorLength, useAdaGrad, lr, gen, negative); - this.xMax = xMax; - this.maxCount = maxCount; - } - - @Override - public void resetWeights(boolean reset) { - if (rng == null) - this.rng = Nd4j.getRandom(); - - //note the +2 which is the unk vocab word and the bias - if (syn0 == null || reset) { - syn0 = Nd4j.rand(new int[] {vocab.numWords() + 1, vectorLength}, rng).subi(0.5).divi((double) vectorLength); - INDArray randUnk = Nd4j.rand(1, vectorLength, rng).subi(0.5).divi(vectorLength); - putVector(Word2Vec.DEFAULT_UNK, randUnk); - } - if (weightAdaGrad == null || reset) { - weightAdaGrad = new AdaGrad(new long[]{vocab.numWords() + 1, vectorLength}, lr.get()); - } - - - //right after unknown - if (bias == null || reset) - bias = Nd4j.create(syn0.rows()); - - if (biasAdaGrad == null || reset) { - biasAdaGrad = new AdaGrad(bias.shape(), lr.get()); - } - - - } - - /** - * Reset the weights of the cache - */ - @Override - public void resetWeights() { - resetWeights(true); - - } - - /** - * glove iteration - * @param w1 the first word - * @param w2 the second word - * @param score the weight learned for the particular co occurrences - */ - public double iterateSample(T w1, T w2, double score) { - INDArray w1Vector = syn0.slice(w1.getIndex()); - INDArray w2Vector = syn0.slice(w2.getIndex()); - //prediction: input + bias - if (w1.getIndex() < 0 || w1.getIndex() >= syn0.rows()) - throw new IllegalArgumentException("Illegal index for word " + w1.getLabel()); - if (w2.getIndex() < 0 || w2.getIndex() >= syn0.rows()) - throw new IllegalArgumentException("Illegal index for word " + w2.getLabel()); - - - //w1 * w2 + bias - double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector); - prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex()); - - double weight = Math.pow(Math.min(1.0, (score / maxCount)), xMax); - - double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score)); - if (Double.isNaN(fDiff)) - fDiff = Nd4j.EPS_THRESHOLD; - //amount of change - double gradient = fDiff; - - //note the update step here: the gradient is - //the gradient of the OPPOSITE word - //for adagrad we will use the index of the word passed in - //for the gradient calculation we will use the context vector - update(w1, w1Vector, w2Vector, gradient); - update(w2, w2Vector, w1Vector, gradient); - return fDiff; - - - - } - - - private void update(T w1, INDArray wordVector, INDArray contextVector, double gradient) { - //gradient for word vectors - INDArray grad1 = contextVector.mul(gradient); - INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape()); - - //update vector - wordVector.subi(update); - - double w1Bias = bias.getDouble(w1.getIndex()); - double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape()); - double update2 = w1Bias - biasGradient; - bias.putScalar(w1.getIndex(), update2); - } - - public AdaGrad getWeightAdaGrad() { - return weightAdaGrad; - } - - - public AdaGrad getBiasAdaGrad() { - return biasAdaGrad; - } - - - - /** - * Load a glove model from an input stream. - * The format is: - * word num1 num2.... - * @param is the input stream to read from for the weights - * @param vocab the vocab for the lookuptable - * @return the loaded model - * @throws java.io.IOException if one occurs - */ - public static GloveWeightLookupTable load(InputStream is, VocabCache vocab) - throws IOException { - LineIterator iter = IOUtils.lineIterator(is, "UTF-8"); - GloveWeightLookupTable glove = null; - Map wordVectors = new HashMap<>(); - while (iter.hasNext()) { - String line = iter.nextLine().trim(); - if (line.isEmpty()) - continue; - String[] split = line.split(" "); - String word = split[0]; - if (glove == null) - glove = new GloveWeightLookupTable.Builder().cache(vocab).vectorLength(split.length - 1).build(); - - - - if (word.isEmpty()) - continue; - float[] read = read(split, glove.layerSize()); - if (read.length < 1) - continue; - - - wordVectors.put(word, read); - - - - } - - glove.setSyn0(weights(glove, wordVectors, vocab)); - glove.resetWeights(false); - - - iter.close(); - - - return glove; - - } - - private static INDArray weights(GloveWeightLookupTable glove, Map data, VocabCache vocab) { - INDArray ret = Nd4j.create(data.size(), glove.layerSize()); - - for (Map.Entry entry : data.entrySet()) { - String key = entry.getKey(); - INDArray row = Nd4j.create(Nd4j.createBuffer(entry.getValue())); - if (row.length() != glove.layerSize()) - continue; - if (vocab.indexOf(key) >= data.size()) - continue; - if (vocab.indexOf(key) < 0) - continue; - ret.putRow(vocab.indexOf(key), row); - } - return ret; - } - - - private static float[] read(String[] split, int length) { - float[] ret = new float[length]; - for (int i = 1; i < split.length; i++) { - ret[i - 1] = Float.parseFloat(split[i]); - } - return ret; - } - - - @Override - public void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha) { - throw new UnsupportedOperationException(); - - } - - public double getxMax() { - return xMax; - } - - public void setxMax(double xMax) { - this.xMax = xMax; - } - - public double getMaxCount() { - return maxCount; - } - - public void setMaxCount(double maxCount) { - this.maxCount = maxCount; - } - - public INDArray getBias() { - return bias; - } - - public void setBias(INDArray bias) { - this.bias = bias; - } - - public static class Builder extends InMemoryLookupTable.Builder { - private double xMax = 0.75; - private double maxCount = 100; - - - public Builder maxCount(double maxCount) { - this.maxCount = maxCount; - return this; - } - - - public Builder xMax(double xMax) { - this.xMax = xMax; - return this; - } - - @Override - public Builder cache(VocabCache vocab) { - super.cache(vocab); - return this; - } - - @Override - public Builder negative(double negative) { - super.negative(negative); - return this; - } - - @Override - public Builder vectorLength(int vectorLength) { - super.vectorLength(vectorLength); - return this; - } - - @Override - public Builder useAdaGrad(boolean useAdaGrad) { - super.useAdaGrad(useAdaGrad); - return this; - } - - @Override - public Builder lr(double lr) { - super.lr(lr); - return this; - } - - @Override - public Builder gen(Random gen) { - super.gen(gen); - return this; - } - - @Override - public Builder seed(long seed) { - super.seed(seed); - return this; - } - - public GloveWeightLookupTable build() { - return new GloveWeightLookupTable<>(vocabCache, vectorLength, useAdaGrad, lr, gen, negative, xMax, - maxCount); - } - } - -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/ASCIICoOccurrenceReader.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/ASCIICoOccurrenceReader.java deleted file mode 100644 index 8dd2fe85f..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/ASCIICoOccurrenceReader.java +++ /dev/null @@ -1,91 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove.count; - -import lombok.NonNull; -import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; -import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator; -import org.deeplearning4j.text.sentenceiterator.SentenceIterator; - -import java.io.File; -import java.io.PrintWriter; - -/** - * @author raver119@gmail.com - */ -public class ASCIICoOccurrenceReader implements CoOccurenceReader { - private File file; - private PrintWriter writer; - private SentenceIterator iterator; - private VocabCache vocabCache; - - public ASCIICoOccurrenceReader(@NonNull File file, @NonNull VocabCache vocabCache) { - this.vocabCache = vocabCache; - this.file = file; - try { - iterator = new PrefetchingSentenceIterator.Builder(new BasicLineIterator(file)).build(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - - @Override - public boolean hasMoreObjects() { - return iterator.hasNext(); - } - - - /** - * Returns next CoOccurrenceWeight object - * - * PLEASE NOTE: This method can return null value. - * @return - */ - @Override - public CoOccurrenceWeight nextObject() { - String line = iterator.nextSentence(); - if (line == null || line.isEmpty()) { - return null; - } - String[] strings = line.split(" "); - - CoOccurrenceWeight object = new CoOccurrenceWeight<>(); - object.setElement1(vocabCache.elementAtIndex(Integer.valueOf(strings[0]))); - object.setElement2(vocabCache.elementAtIndex(Integer.valueOf(strings[1]))); - object.setWeight(Double.parseDouble(strings[2])); - - return object; - } - - - - @Override - public void finish() { - try { - if (writer != null) { - writer.flush(); - writer.close(); - } - } catch (Exception e) { - throw new RuntimeException(e); - } - - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/ASCIICoOccurrenceWriter.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/ASCIICoOccurrenceWriter.java deleted file mode 100644 index 4ef4aada5..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/ASCIICoOccurrenceWriter.java +++ /dev/null @@ -1,69 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove.count; - -import lombok.NonNull; -import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; - -import java.io.BufferedOutputStream; -import java.io.File; -import java.io.FileOutputStream; -import java.io.PrintWriter; - -/** - * @author raver119@gmail.com - */ -public class ASCIICoOccurrenceWriter implements CoOccurrenceWriter { - - private File file; - private PrintWriter writer; - - public ASCIICoOccurrenceWriter(@NonNull File file) { - this.file = file; - try { - this.writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(file), 10 * 1024 * 1024)); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public void writeObject(CoOccurrenceWeight object) { - StringBuilder builder = new StringBuilder(String.valueOf(object.getElement1().getIndex())).append(" ") - .append(String.valueOf(object.getElement2().getIndex())).append(" ") - .append(String.valueOf(object.getWeight())); - writer.println(builder.toString()); - } - - @Override - public void queueObject(CoOccurrenceWeight object) { - throw new UnsupportedOperationException(); - } - - @Override - public void finish() { - try { - writer.flush(); - } catch (Exception e) { - } - - try { - writer.close(); - } catch (Exception e) { - } - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceReader.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceReader.java deleted file mode 100644 index 549cb9aca..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceReader.java +++ /dev/null @@ -1,245 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove.count; - -import lombok.NonNull; -import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.BufferedInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; - -/** - * Binary implementation of CoOccurenceReader interface, used to provide off-memory storage for cooccurrence maps generated for GloVe - * - * @author raver119@gmail.com - */ -public class BinaryCoOccurrenceReader implements CoOccurenceReader { - private VocabCache vocabCache; - private InputStream inputStream; - private File file; - private ArrayBlockingQueue> buffer; - int workers = Math.max(Runtime.getRuntime().availableProcessors() - 1, 1); - private StreamReaderThread readerThread; - private CountMap countMap; - - - protected static final Logger logger = LoggerFactory.getLogger(BinaryCoOccurrenceReader.class); - - public BinaryCoOccurrenceReader(@NonNull File file, @NonNull VocabCache vocabCache, CountMap map) { - this.vocabCache = vocabCache; - this.file = file; - this.countMap = map; - buffer = new ArrayBlockingQueue<>(200000); - - try { - inputStream = new BufferedInputStream(new FileInputStream(this.file), 100 * 1024 * 1024); - //inputStream = new BufferedInputStream(new FileInputStream(file), 1024 * 1024); - readerThread = new StreamReaderThread(inputStream); - readerThread.start(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public boolean hasMoreObjects() { - - if (!buffer.isEmpty()) - return true; - - try { - return readerThread.hasMoreObjects() || !buffer.isEmpty(); - } catch (Exception e) { - throw new RuntimeException(e); - //return false; - } - } - - @Override - public CoOccurrenceWeight nextObject() { - if (!buffer.isEmpty()) { - return buffer.poll(); - } else { - // buffer can be starved, or we're already at the end of file. - if (readerThread.hasMoreObjects()) { - try { - return buffer.poll(3, TimeUnit.SECONDS); - } catch (Exception e) { - return null; - } - } - } - - - return null; - /* - try { - CoOccurrenceWeight ret = new CoOccurrenceWeight<>(); - ret.setElement1(vocabCache.elementAtIndex(inputStream.readInt())); - ret.setElement2(vocabCache.elementAtIndex(inputStream.readInt())); - ret.setWeight(inputStream.readDouble()); - - return ret; - } catch (Exception e) { - return null; - } - */ - } - - @Override - public void finish() { - try { - if (inputStream != null) - inputStream.close(); - } catch (Exception e) { - // - } - } - - private class StreamReaderThread extends Thread implements Runnable { - private InputStream stream; - private AtomicBoolean isReading = new AtomicBoolean(false); - - public StreamReaderThread(@NonNull InputStream stream) { - this.stream = stream; - isReading.set(false); - } - - @Override - public void run() { - try { - // we read pre-defined number of objects as byte array - byte[] array = new byte[16 * 500000]; - while (true) { - int count = stream.read(array); - - isReading.set(true); - if (count == 0) - break; - - // now we deserialize them in separate threads to gain some speedup, if possible - List threads = new ArrayList<>(); - AtomicInteger internalPosition = new AtomicInteger(0); - - for (int t = 0; t < workers; t++) { - threads.add(t, new AsyncDeserializationThread(t, array, buffer, internalPosition, count)); - threads.get(t).start(); - } - - // we'll block this cycle untill all objects are fit into queue - for (int t = 0; t < workers; t++) { - try { - threads.get(t).join(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - isReading.set(false); - if (count < array.length) - break; - } - - } catch (Exception e) { - isReading.set(false); - throw new RuntimeException(e); - } - } - - public boolean hasMoreObjects() { - try { - return stream.available() > 0 || isReading.get(); - } catch (Exception e) { - return false; - } finally { - } - } - } - - /** - * Utility class that accepts byte array as input, and deserialize it into set of CoOccurrenceWeight objects - */ - private class AsyncDeserializationThread extends Thread implements Runnable { - private int threadId; - private byte[] arrayReference; - private ArrayBlockingQueue> targetBuffer; - private AtomicInteger pointer; - private int limit; - - public AsyncDeserializationThread(int threadId, @NonNull byte[] array, - @NonNull ArrayBlockingQueue> targetBuffer, - @NonNull AtomicInteger sharedPointer, int limit) { - this.threadId = threadId; - this.arrayReference = array; - this.targetBuffer = targetBuffer; - this.pointer = sharedPointer; - this.limit = limit; - - - setName("AsynDeserialization thread " + this.threadId); - } - - @Override - public void run() { - ByteBuffer bB = ByteBuffer.wrap(arrayReference); - int position = 0; - while ((position = pointer.getAndAdd(16)) < this.limit) { - if (position >= limit) { - continue; - } - - - int e1idx = bB.getInt(position); - int e2idx = bB.getInt(position + 4); - double eW = bB.getDouble(position + 8); - - - CoOccurrenceWeight object = new CoOccurrenceWeight<>(); - object.setElement1(vocabCache.elementAtIndex(e1idx)); - object.setElement2(vocabCache.elementAtIndex(e2idx)); - - if (countMap != null) { - double mW = countMap.getCount(object.getElement1(), object.getElement2()); - - if (mW > 0) { - eW += mW; - countMap.removePair(object.getElement1(), object.getElement2()); - } - } - object.setWeight(eW); - - try { - targetBuffer.put(object); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - } - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceWriter.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceWriter.java deleted file mode 100644 index 81230802e..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceWriter.java +++ /dev/null @@ -1,78 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove.count; - -import lombok.NonNull; -import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.BufferedOutputStream; -import java.io.DataOutputStream; -import java.io.File; -import java.io.FileOutputStream; - -/** - * @author raver119@gmail.com - */ -public class BinaryCoOccurrenceWriter implements CoOccurrenceWriter { - private File file; - private DataOutputStream outputStream; - - private static final Logger log = LoggerFactory.getLogger(BinaryCoOccurrenceWriter.class); - - public BinaryCoOccurrenceWriter(@NonNull File file) { - this.file = file; - - try { - outputStream = new DataOutputStream( - new BufferedOutputStream(new FileOutputStream(file), 100 * 1024 * 1024)); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public void writeObject(@NonNull CoOccurrenceWeight object) { - try { - // log.info("Saving objects: { [" +object.getElement1().getIndex() +"], [" + object.getElement2().getIndex() + "] }"); - outputStream.writeInt(object.getElement1().getIndex()); - outputStream.writeInt(object.getElement2().getIndex()); - outputStream.writeDouble(object.getWeight()); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public void queueObject(CoOccurrenceWeight object) { - throw new UnsupportedOperationException(); - } - - @Override - public void finish() { - try { - outputStream.flush(); - } catch (Exception e) { - } - - try { - outputStream.close(); - } catch (Exception e) { - } - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurenceReader.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurenceReader.java deleted file mode 100644 index 0eaecc00b..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurenceReader.java +++ /dev/null @@ -1,34 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove.count; - -import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; - -/** - * Created by raver on 24.12.2015. - */ -public interface CoOccurenceReader { - /* - Storage->Memory merging part - */ - boolean hasMoreObjects(); - - - CoOccurrenceWeight nextObject(); - - void finish(); -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurrenceWeight.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurrenceWeight.java deleted file mode 100644 index 251163e0a..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurrenceWeight.java +++ /dev/null @@ -1,54 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove.count; - -import lombok.Data; -import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; - -/** - * Simple POJO holding pairs of elements and their respective weights, used in GloVe -> CoOccurrence - * - * @author raver119@gmail.com - */ -@Data -public class CoOccurrenceWeight { - private T element1; - private T element2; - private double weight; - - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - - CoOccurrenceWeight that = (CoOccurrenceWeight) o; - - if (element1 != null ? !element1.equals(that.element1) : that.element1 != null) - return false; - return element2 != null ? element2.equals(that.element2) : that.element2 == null; - - } - - @Override - public int hashCode() { - int result = element1 != null ? element1.hashCode() : 0; - result = 31 * result + (element2 != null ? element2.hashCode() : 0); - return result; - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurrenceWriter.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurrenceWriter.java deleted file mode 100644 index b7f7a21ea..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CoOccurrenceWriter.java +++ /dev/null @@ -1,43 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove.count; - -import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; - -/** - * Created by fartovii on 25.12.15. - */ -public interface CoOccurrenceWriter { - - /** - * This method implementations should write out objects immediately - * @param object - */ - void writeObject(CoOccurrenceWeight object); - - /** - * This method implementations should queue objects for writing out. - * - * @param object - */ - void queueObject(CoOccurrenceWeight object); - - /** - * Implementations of this method should close everything they use, before eradication - */ - void finish(); -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java deleted file mode 100644 index 274551ebf..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java +++ /dev/null @@ -1,99 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove.count; - -import org.nd4j.shade.guava.util.concurrent.AtomicDouble; -import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; -import org.nd4j.common.primitives.Pair; - -import java.util.Iterator; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -/** - * Drop-in replacement for CounterMap - * - * WORK IN PROGRESS, PLEASE DO NOT USE - * - * @author raver119@gmail.com - */ -public class CountMap { - private volatile Map, AtomicDouble> backingMap = new ConcurrentHashMap<>(); - - public CountMap() { - // placeholder - } - - public void incrementCount(T element1, T element2, double weight) { - Pair tempEntry = new Pair<>(element1, element2); - if (backingMap.containsKey(tempEntry)) { - backingMap.get(tempEntry).addAndGet(weight); - } else { - backingMap.put(tempEntry, new AtomicDouble(weight)); - } - } - - public void removePair(T element1, T element2) { - Pair tempEntry = new Pair<>(element1, element2); - backingMap.remove(tempEntry); - } - - public void removePair(Pair pair) { - backingMap.remove(pair); - } - - public double getCount(T element1, T element2) { - Pair tempEntry = new Pair<>(element1, element2); - if (backingMap.containsKey(tempEntry)) { - return backingMap.get(tempEntry).get(); - } else - return 0; - } - - public double getCount(Pair pair) { - if (backingMap.containsKey(pair)) { - return backingMap.get(pair).get(); - } else - return 0; - } - - public Iterator> getPairIterator() { - return new Iterator>() { - private Iterator> iterator = backingMap.keySet().iterator(); - - @Override - public boolean hasNext() { - return iterator.hasNext(); - } - - @Override - public Pair next() { - //MapEntry entry = iterator.next(); - return iterator.next(); //new Pair<>(entry.getElement1(), entry.getElement2()); - } - - @Override - public void remove() { - throw new UnsupportedOperationException("remove() isn't supported here"); - } - }; - } - - public int size() { - return backingMap.size(); - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/RoundCount.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/RoundCount.java deleted file mode 100644 index d9f729dba..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/RoundCount.java +++ /dev/null @@ -1,86 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove.count; - -import java.util.concurrent.locks.ReentrantReadWriteLock; - -/** - * Simple circular counter, that circulates within 0...Limit, both inclusive - * - * @author raver119@gmail.com - */ -public class RoundCount { - - private int limit = 0; - private int lower = 0; - private int value = 0; - - private ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); - - /** - * Creates new RoundCount instance. - * - * @param limit Maximum top value for this counter. Inclusive. - */ - public RoundCount(int limit) { - this.limit = limit; - } - - /** - * Creates new RoundCount instance. - * - * @param lower - Minimum value for this counter. Inclusive - * @param top - Maximum value for this counter. Inclusive. - */ - public RoundCount(int lower, int top) { - this.limit = top; - this.lower = lower; - } - - public int previous() { - try { - lock.readLock().lock(); - if (value == lower) - return limit; - else - return value - 1; - } finally { - lock.readLock().unlock(); - } - } - - public int get() { - try { - lock.readLock().lock(); - return value; - } finally { - lock.readLock().unlock(); - } - } - - public void tick() { - try { - lock.writeLock().lock(); - if (value == limit) - value = lower; - else - value++; - } finally { - lock.writeLock().unlock(); - } - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java index c007d4b96..e27debd50 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java @@ -763,7 +763,7 @@ public class ParagraphVectors extends Word2Vec { /** - * This method allows you to use pre-built WordVectors model (Word2Vec or GloVe) for ParagraphVectors. + * This method allows you to use pre-built WordVectors model (e.g. Word2Vec) for ParagraphVectors. * Existing model will be transferred into new model before training starts. * * PLEASE NOTE: Non-normalized model is recommended to use here. diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java index d31cc51b0..0e104bb20 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java @@ -520,7 +520,7 @@ public class SequenceVectors extends WordVectorsImpl< } /** - * This method allows you to use pre-built WordVectors model (SkipGram or GloVe) for DBOW sequence learning. + * This method allows you to use pre-built WordVectors model (e.g. SkipGram) for DBOW sequence learning. * Existing model will be transferred into new model before training starts. * * PLEASE NOTE: This model has no effect for elements learning algorithms. Only sequence learning is affected. diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/AbstractCoOccurrencesTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/AbstractCoOccurrencesTest.java deleted file mode 100644 index 8d59b2a5a..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/AbstractCoOccurrencesTest.java +++ /dev/null @@ -1,101 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove; - -import org.deeplearning4j.BaseDL4JTest; -import org.nd4j.common.io.ClassPathResource; -import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator; -import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer; -import org.deeplearning4j.models.word2vec.VocabWord; -import org.deeplearning4j.models.word2vec.wordstore.VocabConstructor; -import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache; -import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; -import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; -import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; -import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Before; -import org.junit.Test; -import org.nd4j.common.primitives.Pair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; - -/** - * @author raver119@gmail.com - */ -public class AbstractCoOccurrencesTest extends BaseDL4JTest { - - private static final Logger log = LoggerFactory.getLogger(AbstractCoOccurrencesTest.class); - - @Before - public void setUp() throws Exception { - - } - - @Test - public void testFit1() throws Exception { - ClassPathResource resource = new ClassPathResource("other/oneline.txt"); - File file = resource.getFile(); - - AbstractCache vocabCache = new AbstractCache.Builder().build(); - BasicLineIterator underlyingIterator = new BasicLineIterator(file); - - TokenizerFactory t = new DefaultTokenizerFactory(); - t.setTokenPreProcessor(new CommonPreprocessor()); - - SentenceTransformer transformer = - new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build(); - - AbstractSequenceIterator sequenceIterator = - new AbstractSequenceIterator.Builder<>(transformer).build(); - - VocabConstructor constructor = new VocabConstructor.Builder() - .addSource(sequenceIterator, 1).setTargetVocabCache(vocabCache).build(); - - constructor.buildJointVocabulary(false, true); - - AbstractCoOccurrences coOccurrences = new AbstractCoOccurrences.Builder() - .iterate(sequenceIterator).vocabCache(vocabCache).symmetric(false).windowSize(15).build(); - - coOccurrences.fit(); - - //List> list = coOccurrences.i(); - Iterator, Double>> iterator = coOccurrences.iterator(); - assertNotEquals(null, iterator); - int cnt = 0; - - List> list = new ArrayList<>(); - while (iterator.hasNext()) { - Pair, Double> pair = iterator.next(); - list.add(pair.getFirst()); - cnt++; - } - - - log.info("CoOccurrences: " + list); - - assertEquals(16, list.size()); - assertEquals(16, cnt); - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/GloveTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/GloveTest.java deleted file mode 100644 index 39aa40d10..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/GloveTest.java +++ /dev/null @@ -1,137 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove; - -import org.deeplearning4j.BaseDL4JTest; -import org.nd4j.common.io.ClassPathResource; -import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; -import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; -import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; -import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator; -import org.deeplearning4j.text.sentenceiterator.SentenceIterator; -import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor; -import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; -import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; -import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.common.resources.Resources; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.util.Collection; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -/** - * Created by agibsonccc on 12/3/14. - */ -public class GloveTest extends BaseDL4JTest { - private static final Logger log = LoggerFactory.getLogger(GloveTest.class); - private Glove glove; - private SentenceIterator iter; - - @Before - public void before() throws Exception { - - ClassPathResource resource = new ClassPathResource("/raw_sentences.txt"); - File file = resource.getFile(); - iter = new LineSentenceIterator(file); - iter.setPreProcessor(new SentencePreProcessor() { - @Override - public String preProcess(String sentence) { - return sentence.toLowerCase(); - } - }); - - } - - - @Ignore - @Test - public void testGlove() throws Exception { - /* - glove = new Glove.Builder().iterate(iter).symmetric(true).shuffle(true) - .minWordFrequency(1).iterations(10).learningRate(0.1) - .layerSize(300) - .build(); - - glove.fit(); - Collection words = glove.wordsNearest("day", 20); - log.info("Nearest words to 'day': " + words); - assertTrue(words.contains("week")); - - */ - - } - - @Ignore - @Test - public void testGloVe1() throws Exception { - File inputFile = Resources.asFile("big/raw_sentences.txt"); - - SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); - // Split on white spaces in the line to get words - TokenizerFactory t = new DefaultTokenizerFactory(); - t.setTokenPreProcessor(new CommonPreprocessor()); - - Glove glove = new Glove.Builder().iterate(iter).tokenizerFactory(t).alpha(0.75).learningRate(0.1).epochs(45) - .xMax(100).shuffle(true).symmetric(true).build(); - - glove.fit(); - - double simD = glove.similarity("day", "night"); - double simP = glove.similarity("best", "police"); - - - - log.info("Day/night similarity: " + simD); - log.info("Best/police similarity: " + simP); - - Collection words = glove.wordsNearest("day", 10); - log.info("Nearest words to 'day': " + words); - - - assertTrue(simD > 0.7); - - // actually simP should be somewhere at 0 - assertTrue(simP < 0.5); - - assertTrue(words.contains("night")); - assertTrue(words.contains("year")); - assertTrue(words.contains("week")); - - File tempFile = File.createTempFile("glove", "temp"); - tempFile.deleteOnExit(); - - INDArray day1 = glove.getWordVectorMatrix("day").dup(); - - WordVectorSerializer.writeWordVectors(glove, tempFile); - - WordVectors vectors = WordVectorSerializer.loadTxtVectors(tempFile); - - INDArray day2 = vectors.getWordVectorMatrix("day").dup(); - - assertEquals(day1, day2); - - tempFile.delete(); - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceReaderTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceReaderTest.java deleted file mode 100644 index 7f357a901..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/count/BinaryCoOccurrenceReaderTest.java +++ /dev/null @@ -1,156 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove.count; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.models.word2vec.Huffman; -import org.deeplearning4j.models.word2vec.VocabWord; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache; -import org.junit.Before; -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; - -import static org.junit.Assert.assertNotEquals; - -/** - * Created by fartovii on 25.12.15. - */ -public class BinaryCoOccurrenceReaderTest extends BaseDL4JTest { - - private static final Logger log = LoggerFactory.getLogger(BinaryCoOccurrenceReaderTest.class); - - @Before - public void setUp() throws Exception { - - } - - @Test - public void testHasMoreObjects1() throws Exception { - File tempFile = File.createTempFile("tmp", "tmp"); - tempFile.deleteOnExit(); - - VocabCache vocabCache = new AbstractCache.Builder().build(); - - VocabWord word1 = new VocabWord(1.0, "human"); - VocabWord word2 = new VocabWord(2.0, "animal"); - VocabWord word3 = new VocabWord(3.0, "unknown"); - - vocabCache.addToken(word1); - vocabCache.addToken(word2); - vocabCache.addToken(word3); - - Huffman huffman = new Huffman(vocabCache.vocabWords()); - huffman.build(); - huffman.applyIndexes(vocabCache); - - - BinaryCoOccurrenceWriter writer = new BinaryCoOccurrenceWriter<>(tempFile); - - CoOccurrenceWeight object1 = new CoOccurrenceWeight<>(); - object1.setElement1(word1); - object1.setElement2(word2); - object1.setWeight(3.14159265); - - writer.writeObject(object1); - - CoOccurrenceWeight object2 = new CoOccurrenceWeight<>(); - object2.setElement1(word2); - object2.setElement2(word3); - object2.setWeight(0.197); - - writer.writeObject(object2); - - writer.finish(); - - BinaryCoOccurrenceReader reader = new BinaryCoOccurrenceReader<>(tempFile, vocabCache, null); - - - CoOccurrenceWeight r1 = reader.nextObject(); - log.info("Object received: " + r1); - assertNotEquals(null, r1); - - r1 = reader.nextObject(); - log.info("Object received: " + r1); - assertNotEquals(null, r1); - } - - @Test - public void testHasMoreObjects2() throws Exception { - File tempFile = File.createTempFile("tmp", "tmp"); - tempFile.deleteOnExit(); - - VocabCache vocabCache = new AbstractCache.Builder().build(); - - VocabWord word1 = new VocabWord(1.0, "human"); - VocabWord word2 = new VocabWord(2.0, "animal"); - VocabWord word3 = new VocabWord(3.0, "unknown"); - - vocabCache.addToken(word1); - vocabCache.addToken(word2); - vocabCache.addToken(word3); - - Huffman huffman = new Huffman(vocabCache.vocabWords()); - huffman.build(); - huffman.applyIndexes(vocabCache); - - - BinaryCoOccurrenceWriter writer = new BinaryCoOccurrenceWriter<>(tempFile); - - CoOccurrenceWeight object1 = new CoOccurrenceWeight<>(); - object1.setElement1(word1); - object1.setElement2(word2); - object1.setWeight(3.14159265); - - writer.writeObject(object1); - - CoOccurrenceWeight object2 = new CoOccurrenceWeight<>(); - object2.setElement1(word2); - object2.setElement2(word3); - object2.setWeight(0.197); - - writer.writeObject(object2); - - CoOccurrenceWeight object3 = new CoOccurrenceWeight<>(); - object3.setElement1(word1); - object3.setElement2(word3); - object3.setWeight(0.001); - - writer.writeObject(object3); - - writer.finish(); - - BinaryCoOccurrenceReader reader = new BinaryCoOccurrenceReader<>(tempFile, vocabCache, null); - - - CoOccurrenceWeight r1 = reader.nextObject(); - log.info("Object received: " + r1); - assertNotEquals(null, r1); - - r1 = reader.nextObject(); - log.info("Object received: " + r1); - assertNotEquals(null, r1); - - r1 = reader.nextObject(); - log.info("Object received: " + r1); - assertNotEquals(null, r1); - - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/count/RoundCountTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/count/RoundCountTest.java deleted file mode 100644 index 737533648..000000000 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/glove/count/RoundCountTest.java +++ /dev/null @@ -1,90 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.models.glove.count; - -import org.deeplearning4j.BaseDL4JTest; -import org.junit.Before; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -/** - * Created by fartovii on 23.12.15. - */ -public class RoundCountTest extends BaseDL4JTest { - - @Before - public void setUp() throws Exception { - - } - - @Test - public void testGet1() throws Exception { - RoundCount count = new RoundCount(1); - - assertEquals(0, count.get()); - - count.tick(); - assertEquals(1, count.get()); - - count.tick(); - assertEquals(0, count.get()); - } - - @Test - public void testGet2() throws Exception { - RoundCount count = new RoundCount(3); - - assertEquals(0, count.get()); - - count.tick(); - assertEquals(1, count.get()); - - count.tick(); - assertEquals(2, count.get()); - - count.tick(); - assertEquals(3, count.get()); - - count.tick(); - assertEquals(0, count.get()); - } - - @Test - public void testPrevious1() throws Exception { - RoundCount count = new RoundCount(3); - - assertEquals(0, count.get()); - assertEquals(3, count.previous()); - - count.tick(); - assertEquals(1, count.get()); - assertEquals(0, count.previous()); - - count.tick(); - assertEquals(2, count.get()); - assertEquals(1, count.previous()); - - count.tick(); - assertEquals(3, count.get()); - assertEquals(2, count.previous()); - - count.tick(); - assertEquals(0, count.get()); - assertEquals(3, count.previous()); - } -} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java index b3ba6e198..6ec46bb7d 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java @@ -21,12 +21,10 @@ import lombok.Getter; import lombok.Setter; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; -import org.deeplearning4j.BaseDL4JTest; -import org.nd4j.common.io.ClassPathResource; import org.datavec.api.writable.Writable; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; -import org.deeplearning4j.models.embeddings.learning.impl.elements.GloVe; import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram; import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration; import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; @@ -55,6 +53,7 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.common.io.ClassPathResource; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.heartbeat.Heartbeat; import org.slf4j.Logger; @@ -270,65 +269,6 @@ public class SequenceVectorsTest extends BaseDL4JTest { .epochs(1).resetModel(false).trainElementsRepresentation(false).build(); } - @Ignore - @Test - public void testGlove1() throws Exception { - logger.info("Max available memory: " + Runtime.getRuntime().maxMemory()); - ClassPathResource resource = new ClassPathResource("big/raw_sentences.txt"); - File file = resource.getFile(); - - BasicLineIterator underlyingIterator = new BasicLineIterator(file); - - TokenizerFactory t = new DefaultTokenizerFactory(); - t.setTokenPreProcessor(new CommonPreprocessor()); - - SentenceTransformer transformer = - new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build(); - - AbstractSequenceIterator sequenceIterator = - new AbstractSequenceIterator.Builder<>(transformer).build(); - - VectorsConfiguration configuration = new VectorsConfiguration(); - configuration.setWindow(5); - configuration.setLearningRate(0.06); - configuration.setLayersSize(100); - - - SequenceVectors vectors = new SequenceVectors.Builder(configuration) - .iterate(sequenceIterator).iterations(1).epochs(45) - .elementsLearningAlgorithm(new GloVe.Builder().shuffle(true).symmetric(true) - .learningRate(0.05).alpha(0.75).xMax(100.0).build()) - .resetModel(true).trainElementsRepresentation(true).trainSequencesRepresentation(false).build(); - - vectors.fit(); - - double sim = vectors.similarity("day", "night"); - logger.info("Day/night similarity: " + sim); - - - sim = vectors.similarity("day", "another"); - logger.info("Day/another similarity: " + sim); - - sim = vectors.similarity("night", "year"); - logger.info("Night/year similarity: " + sim); - - sim = vectors.similarity("night", "me"); - logger.info("Night/me similarity: " + sim); - - sim = vectors.similarity("day", "know"); - logger.info("Day/know similarity: " + sim); - - sim = vectors.similarity("best", "police"); - logger.info("Best/police similarity: " + sim); - - Collection labels = vectors.wordsNearest("day", 10); - logger.info("Nearest labels to 'day': " + labels); - - - sim = vectors.similarity("day", "night"); - assertTrue(sim > 0.6d); - } - @Test @Ignore public void testDeepWalk() throws Exception { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/Glove.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/Glove.java deleted file mode 100644 index 81de8effb..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/Glove.java +++ /dev/null @@ -1,280 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.models.embeddings.glove; - -import org.apache.commons.math3.util.FastMath; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.broadcast.Broadcast; -import org.deeplearning4j.models.glove.GloveWeightLookupTable; -import org.deeplearning4j.models.word2vec.VocabWord; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCalculator; -import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCounts; -import org.deeplearning4j.spark.text.functions.TextPipeline; -import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.learning.legacy.AdaGrad; -import org.nd4j.common.primitives.CounterMap; -import org.nd4j.common.primitives.Pair; -import org.nd4j.common.primitives.Triple; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import scala.Tuple2; - -import java.io.Serializable; -import java.util.*; -import java.util.concurrent.atomic.AtomicLong; - -import static org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables.*; - -/** - * Spark glove - * - * @author Adam Gibson - */ -public class Glove implements Serializable { - - private Broadcast> vocabCacheBroadcast; - private String tokenizerFactoryClazz = DefaultTokenizerFactory.class.getName(); - private boolean symmetric = true; - private int windowSize = 15; - private int iterations = 300; - private static Logger log = LoggerFactory.getLogger(Glove.class); - - /** - * - * @param tokenizerFactoryClazz the fully qualified class name of the tokenizer - * @param symmetric whether the co occurrence counts should be symmetric - * @param windowSize the window size for co occurrence - * @param iterations the number of iterations - */ - public Glove(String tokenizerFactoryClazz, boolean symmetric, int windowSize, int iterations) { - this.tokenizerFactoryClazz = tokenizerFactoryClazz; - this.symmetric = symmetric; - this.windowSize = windowSize; - this.iterations = iterations; - } - - /** - * - * @param symmetric whether the co occurrence counts should be symmetric - * @param windowSize the window size for co occurrence - * @param iterations the number of iterations - */ - public Glove(boolean symmetric, int windowSize, int iterations) { - this.symmetric = symmetric; - this.windowSize = windowSize; - this.iterations = iterations; - } - - - private Pair update(AdaGrad weightAdaGrad, AdaGrad biasAdaGrad, INDArray syn0, INDArray bias, - VocabWord w1, INDArray wordVector, INDArray contextVector, double gradient) { - //gradient for word vectors - INDArray grad1 = contextVector.mul(gradient); - INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape()); - wordVector.subi(update); - - double w1Bias = bias.getDouble(w1.getIndex()); - double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape()); - double update2 = w1Bias - biasGradient; - bias.putScalar(w1.getIndex(), bias.getDouble(w1.getIndex()) - update2); - return new Pair<>(update, (float) update2); - } - - /** - * Train on the corpus - * @param rdd the rdd to train - * @return the vocab and weights - */ - public Pair, GloveWeightLookupTable> train(JavaRDD rdd) throws Exception { - // Each `train()` can use different parameters - final JavaSparkContext sc = new JavaSparkContext(rdd.context()); - final SparkConf conf = sc.getConf(); - final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class); - final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class); - final double negative = assignVar(NEGATIVE, conf, Double.class); - final int numWords = assignVar(NUM_WORDS, conf, Integer.class); - final int window = assignVar(WINDOW, conf, Integer.class); - final double alpha = assignVar(ALPHA, conf, Double.class); - final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class); - final int iterations = assignVar(ITERATIONS, conf, Integer.class); - final int nGrams = assignVar(N_GRAMS, conf, Integer.class); - final String tokenizer = assignVar(TOKENIZER, conf, String.class); - final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class); - final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class); - - Map tokenizerVarMap = new HashMap() { - { - put("numWords", numWords); - put("nGrams", nGrams); - put("tokenizer", tokenizer); - put("tokenPreprocessor", tokenPreprocessor); - put("removeStop", removeStop); - } - }; - Broadcast> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap); - - - TextPipeline pipeline = new TextPipeline(rdd, broadcastTokenizerVarMap); - pipeline.buildVocabCache(); - pipeline.buildVocabWordListRDD(); - - - // Get total word count - Long totalWordCount = pipeline.getTotalWordCount(); - VocabCache vocabCache = pipeline.getVocabCache(); - JavaRDD, AtomicLong>> sentenceWordsCountRDD = pipeline.getSentenceWordsCountRDD(); - final Pair, Long> vocabAndNumWords = new Pair<>(vocabCache, totalWordCount); - - vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst()); - - final GloveWeightLookupTable gloveWeightLookupTable = new GloveWeightLookupTable.Builder() - .cache(vocabAndNumWords.getFirst()).lr(conf.getDouble(GlovePerformer.ALPHA, 0.01)) - .maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100)) - .vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300)) - .xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75)).build(); - gloveWeightLookupTable.resetWeights(); - - gloveWeightLookupTable.getBiasAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().rows()); - gloveWeightLookupTable.getWeightAdaGrad().historicalGradient = - Nd4j.ones(gloveWeightLookupTable.getSyn0().shape()); - - - log.info("Created lookup table of size " + Arrays.toString(gloveWeightLookupTable.getSyn0().shape())); - CounterMap coOccurrenceCounts = sentenceWordsCountRDD - .map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize)) - .fold(new CounterMap(), new CoOccurrenceCounts()); - Iterator> pair2 = coOccurrenceCounts.getIterator(); - List> counts = new ArrayList<>(); - - while (pair2.hasNext()) { - Pair next = pair2.next(); - if (coOccurrenceCounts.getCount(next.getFirst(), next.getSecond()) > gloveWeightLookupTable.getMaxCount()) { - coOccurrenceCounts.setCount(next.getFirst(), next.getSecond(), - (float) gloveWeightLookupTable.getMaxCount()); - } - counts.add(new Triple<>(next.getFirst(), next.getSecond(), - (float) coOccurrenceCounts.getCount(next.getFirst(), next.getSecond()))); - - } - - log.info("Calculated co occurrences"); - - JavaRDD> parallel = sc.parallelize(counts); - JavaPairRDD> pairs = parallel - .mapToPair(new PairFunction, String, Tuple2>() { - @Override - public Tuple2> call( - Triple stringStringDoubleTriple) throws Exception { - return new Tuple2<>(stringStringDoubleTriple.getFirst(), - new Tuple2<>(stringStringDoubleTriple.getSecond(), - stringStringDoubleTriple.getThird())); - } - }); - - JavaPairRDD> pairsVocab = pairs.mapToPair( - new PairFunction>, VocabWord, Tuple2>() { - @Override - public Tuple2> call( - Tuple2> stringTuple2Tuple2) throws Exception { - VocabWord w1 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1()); - VocabWord w2 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1()); - return new Tuple2<>(w1, new Tuple2<>(w2, stringTuple2Tuple2._2()._2())); - } - }); - - - for (int i = 0; i < iterations; i++) { - JavaRDD change = - pairsVocab.map(new Function>, GloveChange>() { - @Override - public GloveChange call( - Tuple2> vocabWordTuple2Tuple2) - throws Exception { - VocabWord w1 = vocabWordTuple2Tuple2._1(); - VocabWord w2 = vocabWordTuple2Tuple2._2()._1(); - INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex()); - INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex()); - INDArray bias = gloveWeightLookupTable.getBias(); - double score = vocabWordTuple2Tuple2._2()._2(); - double xMax = gloveWeightLookupTable.getxMax(); - double maxCount = gloveWeightLookupTable.getMaxCount(); - //w1 * w2 + bias - double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector); - prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex()); - - double weight = FastMath.pow(Math.min(1.0, (score / maxCount)), xMax); - - double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score)); - if (Double.isNaN(fDiff)) - fDiff = Nd4j.EPS_THRESHOLD; - //amount of change - double gradient = fDiff; - - Pair w1Update = update(gloveWeightLookupTable.getWeightAdaGrad(), - gloveWeightLookupTable.getBiasAdaGrad(), - gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), - w1, w1Vector, w2Vector, gradient); - Pair w2Update = update(gloveWeightLookupTable.getWeightAdaGrad(), - gloveWeightLookupTable.getBiasAdaGrad(), - gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), - w2, w2Vector, w1Vector, gradient); - return new GloveChange(w1, w2, w1Update.getFirst(), w2Update.getFirst(), - w1Update.getSecond(), w2Update.getSecond(), fDiff, - gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient() - .slice(w1.getIndex()), - gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient() - .slice(w2.getIndex()), - gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient() - .getDouble(w2.getIndex()), - gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient() - .getDouble(w1.getIndex())); - - } - }); - - - - List gloveChanges = change.collect(); - double error = 0.0; - for (GloveChange change2 : gloveChanges) { - change2.apply(gloveWeightLookupTable); - error += change2.getError(); - } - - - List l = pairsVocab.collect(); - Collections.shuffle(l); - pairsVocab = sc.parallelizePairs(l); - - log.info("Error at iteration " + i + " was " + error); - - - - } - - return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable); - } - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/GloveChange.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/GloveChange.java deleted file mode 100644 index 64a5e2541..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/GloveChange.java +++ /dev/null @@ -1,163 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.models.embeddings.glove; - -import org.deeplearning4j.models.glove.GloveWeightLookupTable; -import org.deeplearning4j.models.word2vec.VocabWord; -import org.nd4j.linalg.api.ndarray.INDArray; - -import java.io.Serializable; - -/** - * @author Adam Gibson - */ -public class GloveChange implements Serializable { - private VocabWord w1, w2; - private INDArray w1Update, w2Update; - private double w1BiasUpdate, w2BiasUpdate; - private double error; - private INDArray w1History, w2History; - private double w1BiasHistory, w2BiasHistory; - - public GloveChange(VocabWord w1, VocabWord w2, INDArray w1Update, INDArray w2Update, double w1BiasUpdate, - double w2BiasUpdate, double error, INDArray w1History, INDArray w2History, double w1BiasHistory, - double w2BiasHistory) { - this.w1 = w1; - this.w2 = w2; - this.w1Update = w1Update; - this.w2Update = w2Update; - this.w1BiasUpdate = w1BiasUpdate; - this.w2BiasUpdate = w2BiasUpdate; - this.error = error; - this.w1History = w1History; - this.w2History = w2History; - this.w1BiasHistory = w1BiasHistory; - this.w2BiasHistory = w2BiasHistory; - } - - /** - * Apply the changes to the table - * @param table - */ - public void apply(GloveWeightLookupTable table) { - table.getBias().putScalar(w1.getIndex(), table.getBias().getDouble(w1.getIndex()) - w1BiasUpdate); - table.getBias().putScalar(w2.getIndex(), table.getBias().getDouble(w2.getIndex()) - w2BiasUpdate); - table.getSyn0().slice(w1.getIndex()).subi(w1Update); - table.getSyn0().slice(w2.getIndex()).subi(w2Update); - table.getWeightAdaGrad().getHistoricalGradient().slice(w1.getIndex()).addi(w1History); - table.getWeightAdaGrad().getHistoricalGradient().slice(w2.getIndex()).addi(w2History); - table.getBiasAdaGrad().getHistoricalGradient().putScalar(w1.getIndex(), - table.getBiasAdaGrad().getHistoricalGradient().getDouble(w1.getIndex()) + w1BiasHistory); - table.getBiasAdaGrad().getHistoricalGradient().putScalar(w2.getIndex(), - table.getBiasAdaGrad().getHistoricalGradient().getDouble(w2.getIndex()) + w1BiasHistory); - - } - - public INDArray getW1History() { - return w1History; - } - - public void setW1History(INDArray w1History) { - this.w1History = w1History; - } - - public INDArray getW2History() { - return w2History; - } - - public void setW2History(INDArray w2History) { - this.w2History = w2History; - } - - public double getW1BiasHistory() { - return w1BiasHistory; - } - - public void setW1BiasHistory(double w1BiasHistory) { - this.w1BiasHistory = w1BiasHistory; - } - - public double getW2BiasHistory() { - return w2BiasHistory; - } - - public void setW2BiasHistory(double w2BiasHistory) { - this.w2BiasHistory = w2BiasHistory; - } - - public VocabWord getW1() { - return w1; - } - - public void setW1(VocabWord w1) { - this.w1 = w1; - } - - public VocabWord getW2() { - return w2; - } - - public void setW2(VocabWord w2) { - this.w2 = w2; - } - - public INDArray getW1Update() { - return w1Update; - } - - public void setW1Update(INDArray w1Update) { - this.w1Update = w1Update; - } - - public INDArray getW2Update() { - return w2Update; - } - - public void setW2Update(INDArray w2Update) { - this.w2Update = w2Update; - } - - public double getW1BiasUpdate() { - return w1BiasUpdate; - } - - public void setW1BiasUpdate(double w1BiasUpdate) { - this.w1BiasUpdate = w1BiasUpdate; - } - - public double getW2BiasUpdate() { - return w2BiasUpdate; - } - - public void setW2BiasUpdate(double w2BiasUpdate) { - this.w2BiasUpdate = w2BiasUpdate; - } - - public double getError() { - return error; - } - - public void setError(double error) { - this.error = error; - } - - @Override - public String toString() { - return w1.getIndex() + "," + w2.getIndex() + " error " + error; - } - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/GloveParam.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/GloveParam.java deleted file mode 100644 index d66593322..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/GloveParam.java +++ /dev/null @@ -1,171 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.models.embeddings.glove; - -import org.apache.spark.broadcast.Broadcast; -import org.nd4j.linalg.api.rng.Random; -import org.nd4j.common.primitives.CounterMap; - -import java.io.Serializable; - -/** - * @author Adam Gibson - */ -public class GloveParam implements Serializable { - - private int vectorLength; - private boolean useAdaGrad; - private double lr; - private Random gen; - private double negative; - private double xMax; - private double maxCount; - private Broadcast> coOccurrenceCounts; - - public GloveParam(int vectorLength, boolean useAdaGrad, double lr, Random gen, double negative, double xMax, - double maxCount, Broadcast> coOccurrenceCounts) { - this.vectorLength = vectorLength; - this.useAdaGrad = useAdaGrad; - this.lr = lr; - this.gen = gen; - this.negative = negative; - this.xMax = xMax; - this.maxCount = maxCount; - this.coOccurrenceCounts = coOccurrenceCounts; - } - - public int getVectorLength() { - return vectorLength; - } - - public void setVectorLength(int vectorLength) { - this.vectorLength = vectorLength; - } - - public boolean isUseAdaGrad() { - return useAdaGrad; - } - - public void setUseAdaGrad(boolean useAdaGrad) { - this.useAdaGrad = useAdaGrad; - } - - public double getLr() { - return lr; - } - - public void setLr(double lr) { - this.lr = lr; - } - - public Random getGen() { - return gen; - } - - public void setGen(Random gen) { - this.gen = gen; - } - - public double getNegative() { - return negative; - } - - public void setNegative(double negative) { - this.negative = negative; - } - - public double getxMax() { - return xMax; - } - - public void setxMax(double xMax) { - this.xMax = xMax; - } - - public double getMaxCount() { - return maxCount; - } - - public void setMaxCount(double maxCount) { - this.maxCount = maxCount; - } - - public Broadcast> getCoOccurrenceCounts() { - return coOccurrenceCounts; - } - - public void setCoOccurrenceCounts(Broadcast> coOccurrenceCounts) { - this.coOccurrenceCounts = coOccurrenceCounts; - } - - - public static class Builder { - private int vectorLength = 300; - private boolean useAdaGrad = true; - private double lr = 0.025; - private Random gen; - private double negative = 5; - private double xMax = 0.75; - private double maxCount = 100; - private Broadcast> coOccurrenceCounts; - - public Builder vectorLength(int vectorLength) { - this.vectorLength = vectorLength; - return this; - } - - public Builder useAdaGrad(boolean useAdaGrad) { - this.useAdaGrad = useAdaGrad; - return this; - } - - public Builder lr(double lr) { - this.lr = lr; - return this; - } - - public Builder gen(Random gen) { - this.gen = gen; - return this; - } - - public Builder negative(double negative) { - this.negative = negative; - return this; - } - - public Builder xMax(double xMax) { - this.xMax = xMax; - return this; - } - - public Builder maxCount(double maxCount) { - this.maxCount = maxCount; - return this; - } - - public Builder coOccurrenceCounts(Broadcast> coOccurrenceCounts) { - this.coOccurrenceCounts = coOccurrenceCounts; - return this; - } - - public GloveParam build() { - return new GloveParam(vectorLength, useAdaGrad, lr, gen, negative, xMax, maxCount, coOccurrenceCounts); - } - } - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/GlovePerformer.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/GlovePerformer.java deleted file mode 100644 index bed7643c0..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/GlovePerformer.java +++ /dev/null @@ -1,48 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.models.embeddings.glove; - -import org.apache.spark.api.java.function.Function; -import org.deeplearning4j.models.glove.GloveWeightLookupTable; -import org.deeplearning4j.models.word2vec.VocabWord; -import org.nd4j.common.primitives.Triple; - - -/** - * Base line glove performer - * - * @author Adam Gibson - */ -public class GlovePerformer implements Function, GloveChange> { - - - public final static String NAME_SPACE = "org.deeplearning4j.scaleout.perform.models.glove"; - public final static String VECTOR_LENGTH = NAME_SPACE + ".length"; - public final static String ALPHA = NAME_SPACE + ".alpha"; - public final static String X_MAX = NAME_SPACE + ".xmax"; - public final static String MAX_COUNT = NAME_SPACE + ".maxcount"; - private GloveWeightLookupTable table; - - public GlovePerformer(GloveWeightLookupTable table) { - this.table = table; - } - - @Override - public GloveChange call(Triple pair) throws Exception { - return null; - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/VocabWordPairs.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/VocabWordPairs.java deleted file mode 100644 index 7d6e09edb..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/VocabWordPairs.java +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.models.embeddings.glove; - -import org.apache.spark.api.java.function.Function; -import org.apache.spark.broadcast.Broadcast; -import org.deeplearning4j.models.word2vec.VocabWord; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.nd4j.common.primitives.Triple; - -/** - * Convert string to vocab words - * - * @author Adam Gibson - */ -public class VocabWordPairs implements Function, Triple> { - private Broadcast> vocab; - - public VocabWordPairs(Broadcast> vocab) { - this.vocab = vocab; - } - - @Override - public Triple call(Triple v1) throws Exception { - return new Triple<>((VocabWord) vocab.getValue().wordFor(v1.getFirst()), - (VocabWord) vocab.getValue().wordFor(v1.getSecond()), v1.getThird()); - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/cooccurrences/CoOccurrenceCalculator.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/cooccurrences/CoOccurrenceCalculator.java deleted file mode 100644 index 077a9c3b0..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/cooccurrences/CoOccurrenceCalculator.java +++ /dev/null @@ -1,91 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.models.embeddings.glove.cooccurrences; - -import org.apache.spark.api.java.function.Function; -import org.apache.spark.broadcast.Broadcast; -import org.deeplearning4j.models.word2vec.VocabWord; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.CounterMap; -import org.nd4j.common.primitives.Pair; - -import java.util.List; -import java.util.concurrent.atomic.AtomicLong; - -/** - * Calculate co occurrences based on tokens - * - * @author Adam Gibson - */ -public class CoOccurrenceCalculator implements Function, AtomicLong>, CounterMap> { - private boolean symmetric = false; - private Broadcast> vocab; - private int windowSize = 5; - - public CoOccurrenceCalculator(boolean symmetric, Broadcast> vocab, int windowSize) { - this.symmetric = symmetric; - this.vocab = vocab; - this.windowSize = windowSize; - } - - - @Override - public CounterMap call(Pair, AtomicLong> pair) throws Exception { - List sentence = pair.getFirst(); - CounterMap coOCurreneCounts = new CounterMap<>(); - VocabCache vocab = this.vocab.value(); - for (int i = 0; i < sentence.size(); i++) { - int wordIdx = vocab.indexOf(sentence.get(i)); - String w1 = ((VocabWord) vocab.wordFor(sentence.get(i))).getWord(); - - if (wordIdx < 0) // || w1.equals(Glove.UNK)) - continue; - int windowStop = Math.min(i + windowSize + 1, sentence.size()); - for (int j = i; j < windowStop; j++) { - int otherWord = vocab.indexOf(sentence.get(j)); - String w2 = ((VocabWord) vocab.wordFor(sentence.get(j))).getWord(); - if (vocab.indexOf(sentence.get(j)) < 0) // || w2.equals(Glove.UNK)) - continue; - - if (otherWord == wordIdx) - continue; - if (wordIdx < otherWord) { - coOCurreneCounts.incrementCount(sentence.get(i), sentence.get(j), - (float) (1.0 / (j - i + Nd4j.EPS_THRESHOLD))); - if (symmetric) - coOCurreneCounts.incrementCount(sentence.get(j), sentence.get(i), - (float) (1.0 / (j - i + Nd4j.EPS_THRESHOLD))); - - - - } else { - float coCount = (float) (1.0 / (j - i + Nd4j.EPS_THRESHOLD)); - coOCurreneCounts.incrementCount(sentence.get(j), sentence.get(i), (float) coCount); - if (symmetric) - coOCurreneCounts.incrementCount(sentence.get(i), sentence.get(j), - (float) (1.0 / (j - i + Nd4j.EPS_THRESHOLD))); - - - } - - - } - } - return coOCurreneCounts; - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/cooccurrences/CoOccurrenceCounts.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/cooccurrences/CoOccurrenceCounts.java deleted file mode 100644 index d79706ea5..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/glove/cooccurrences/CoOccurrenceCounts.java +++ /dev/null @@ -1,37 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.models.embeddings.glove.cooccurrences; - -import org.apache.spark.api.java.function.Function2; -import org.nd4j.common.primitives.CounterMap; - - -/** - * Co occurrence count reduction - * @author Adam Gibson - */ -public class CoOccurrenceCounts implements - Function2, CounterMap, CounterMap> { - - - @Override - public CounterMap call(CounterMap v1, CounterMap v2) - throws Exception { - v1.incrementAll(v2); - return v1; - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/glove/GloveTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/glove/GloveTest.java deleted file mode 100644 index 61ab27b54..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/glove/GloveTest.java +++ /dev/null @@ -1,62 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.models.embeddings.glove; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; -import org.nd4j.common.io.ClassPathResource; -import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; -import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; -import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; -import org.deeplearning4j.models.glove.GloveWeightLookupTable; -import org.deeplearning4j.models.word2vec.VocabWord; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.deeplearning4j.spark.text.BaseSparkTest; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.common.primitives.Pair; - -import java.util.Collection; - -import static org.junit.Assert.assertTrue; - -/** - * Created by agibsonccc on 1/31/15. - */ -@Ignore -public class GloveTest extends BaseSparkTest { - - @Test - public void testGlove() throws Exception { - Glove glove = new Glove(true, 5, 100); - JavaRDD corpus = sc.textFile(new ClassPathResource("big/raw_sentences.txt").getFile().getAbsolutePath()) - .map(new Function() { - @Override - public String call(String s) throws Exception { - return s.toLowerCase(); - } - }); - - - Pair, GloveWeightLookupTable> table = glove.train(corpus); - WordVectors vectors = WordVectorSerializer - .fromPair(new Pair<>((InMemoryLookupTable) table.getSecond(), (VocabCache) table.getFirst())); - Collection words = vectors.wordsNearest("day", 20); - assertTrue(words.contains("week")); - } - -}